diff --git a/src/rapidata/rapidata_client/benchmark/participant/_participant.py b/src/rapidata/rapidata_client/benchmark/participant/_participant.py index 1247043f8..999aff00a 100644 --- a/src/rapidata/rapidata_client/benchmark/participant/_participant.py +++ b/src/rapidata/rapidata_client/benchmark/participant/_participant.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import time -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.rapidata_client.config import logger from rapidata.rapidata_client.config.rapidata_config import rapidata_config diff --git a/src/rapidata/rapidata_client/config/order_config.py b/src/rapidata/rapidata_client/config/order_config.py deleted file mode 100644 index c9990507d..000000000 --- a/src/rapidata/rapidata_client/config/order_config.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel, Field - - -class OrderConfig(BaseModel): - """ - Holds the configuration for the order process. - - Attributes: - minOrderDatapointsForValidation (int): The minimum number of datapoints required so that an automatic validationset gets created if no recommended was found. Defaults to 50. - autoValidationSetSize (int): The maximum size of the auto-generated validation set. Defaults to 20. - """ - - autoValidationSetCreation: bool = Field(default=True) - minOrderDatapointsForValidation: int = Field(default=20) diff --git a/src/rapidata/rapidata_client/config/rapidata_config.py b/src/rapidata/rapidata_client/config/rapidata_config.py index 02de167fd..1d2fedfce 100644 --- a/src/rapidata/rapidata_client/config/rapidata_config.py +++ b/src/rapidata/rapidata_client/config/rapidata_config.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, Field from rapidata.rapidata_client.config.logging_config import LoggingConfig -from rapidata.rapidata_client.config.order_config import OrderConfig from rapidata.rapidata_client.config.upload_config import UploadConfig @@ -15,8 +14,6 @@ class RapidataConfig(BaseModel): enableBetaFeatures (bool): Whether to enable beta features. Defaults to False. upload (UploadConfig): The configuration for the upload process. Such as the maximum number of worker threads for processing media paths and the maximum number of retries for failed uploads. - order (OrderConfig): The configuration for the order process. - Such as the minimum number of datapoints required so that an automatic validationset gets created if no recommended was found. logging (LoggingConfig): The configuration for the logging process. Such as the logging level and the logging file. @@ -29,7 +26,6 @@ class RapidataConfig(BaseModel): enableBetaFeatures: bool = False upload: UploadConfig = Field(default_factory=UploadConfig) - order: OrderConfig = Field(default_factory=OrderConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index b2981c390..1b46e55b0 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -1,6 +1,6 @@ from pathlib import Path import shutil -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from rapidata.rapidata_client.config import logger @@ -11,19 +11,29 @@ class UploadConfig(BaseModel): Attributes: maxWorkers (int): The maximum number of worker threads for concurrent uploads. Defaults to 25. 
maxRetries (int): The maximum number of retries for failed uploads. Defaults to 3. - cacheUploads (bool): Enable/disable upload caching. Defaults to True. + cacheToDisk (bool): Enable disk-based caching for file uploads. If False, uses in-memory cache only. Defaults to True. + Note: URL assets are always cached in-memory regardless of this setting. + Caching cannot be disabled entirely as it's required for the two-step upload flow. - cacheTimeout (float): Cache operation timeout in seconds. Defaults to 0.1. + cacheTimeout (float): Cache operation timeout in seconds. Defaults to 1. cacheLocation (Path): Directory for cache storage. Defaults to ~/.cache/rapidata/upload_cache. - This is immutable + This is immutable. Only used for file uploads when cacheToDisk=True. cacheShards (int): Number of cache shards for parallel access. Defaults to 128. Higher values improve concurrency but increase file handles. Must be positive. - This is immutable + This is immutable. Only used for file uploads when cacheToDisk=True. + batchSize (int): Number of URLs per batch (100-5000). Defaults to 1000. + batchPollInterval (float): Polling interval in seconds. Defaults to 0.5. """ + model_config = ConfigDict(validate_assignment=True) + maxWorkers: int = Field(default=25) maxRetries: int = Field(default=3) - cacheUploads: bool = Field(default=True) - cacheTimeout: float = Field(default=0.1) + cacheToDisk: bool = Field( + default=True, + description="Enable disk-based caching for file uploads. URLs are always cached in-memory.", + ) + cacheTimeout: float = Field(default=1) cacheLocation: Path = Field( default=Path.home() / ".cache" / "rapidata" / "upload_cache", frozen=True, @@ -32,6 +42,14 @@ class UploadConfig(BaseModel): default=128, frozen=True, ) + batchSize: int = Field( + default=1000, + description="Number of URLs per batch (100-5000)", + ) + batchPollInterval: float = Field( + default=0.5, + description="Polling interval in seconds", + ) @field_validator("maxWorkers") @classmethod @@ -54,6 +72,13 @@ def validate_cache_shards(cls, v: int) -> int: ) return v + @field_validator("batchSize") + @classmethod + def validate_batch_size(cls, v: int) -> int: + if v < 100: + raise ValueError("batchSize must be at least 100") + return v + def __init__(self, **kwargs): super().__init__(**kwargs) self._migrate_cache() diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py new file mode 100644 index 000000000..8d404f500 --- /dev/null +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Callable, TYPE_CHECKING + +from tqdm.auto import tqdm + +from rapidata.rapidata_client.config import logger, rapidata_config +from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader +from rapidata.rapidata_client.datapoints._batch_asset_uploader import ( + BatchAssetUploader, +) +from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.datapoints._single_flight_cache import SingleFlightCache + +if TYPE_CHECKING: + from rapidata.rapidata_client.datapoints._datapoint import Datapoint + from rapidata.service.openapi_service import OpenAPIService + + +def extract_assets_from_datapoint(datapoint: Datapoint) -> set[str]: + """ + Extract all assets from a single datapoint.
+ + Args: + datapoint: The datapoint to extract assets from. + + Returns: + Set of asset identifiers (URLs or file paths). + """ + assets: set[str] = set() + + # Main asset(s) + if isinstance(datapoint.asset, list): + assets.update(datapoint.asset) + else: + assets.add(datapoint.asset) + + # Context asset + if datapoint.media_context: + assets.add(datapoint.media_context) + + return assets + + +class AssetUploadOrchestrator: + """ + Orchestrates Step 1/2: Upload ALL assets from ALL datapoints. + + This class extracts all unique assets, separates URLs from files, + filters cached assets, and uploads uncached assets using batch + upload for URLs and parallel upload for files. + + Returns list of failed uploads for any assets that fail. + """ + + def __init__(self, openapi_service: OpenAPIService) -> None: + self.asset_uploader = AssetUploader(openapi_service) + self.batch_uploader = BatchAssetUploader(openapi_service) + + def upload_all_assets( + self, + assets: set[str] | list[str], + asset_completion_callback: Callable[[list[str]], None] | None = None, + ) -> list[FailedUpload[str]]: + """ + Step 1/2: Upload ALL assets. + Returns list of failed uploads for any assets that fail. + + Args: + assets: Set or list of asset identifiers (URLs or file paths) to upload. + asset_completion_callback: Optional callback to notify when assets complete (called with list of successful assets). + + Returns: + List of FailedUpload instances for any assets that failed. + """ + # 1. Validate assets + all_assets = set(assets) if isinstance(assets, list) else assets + logger.info(f"Uploading {len(all_assets)} unique asset(s)") + + if not all_assets: + logger.debug("No assets to upload") + return [] + + # 2. Separate and filter assets + urls, files = self._separate_urls_and_files(all_assets) + uncached_urls, uncached_files = self._filter_and_log_cached_assets(urls, files) + + # Notify callback about cached assets (already complete) + cached_assets = [] + cached_assets.extend([url for url in urls if url not in uncached_urls]) + cached_assets.extend([file for file in files if file not in uncached_files]) + + if cached_assets and asset_completion_callback: + logger.debug(f"Notifying callback of {len(cached_assets)} cached asset(s)") + asset_completion_callback(cached_assets) + + if len(uncached_urls) + len(uncached_files) == 0: + logger.debug("All assets cached, nothing to upload") + return [] + + # 3. Perform uploads + failed_uploads = self._perform_uploads( + uncached_urls, uncached_files, asset_completion_callback + ) + + # 4. Report results + self._log_upload_results(failed_uploads) + return failed_uploads + + def _separate_urls_and_files(self, assets: set[str]) -> tuple[list[str], list[str]]: + """ + Separate assets into URLs and file paths. + + Args: + assets: Set of asset identifiers. + + Returns: + Tuple of (urls, files). + """ + urls = [a for a in assets if re.match(r"^https?://", a)] + files = [a for a in assets if not re.match(r"^https?://", a)] + logger.debug(f"Asset breakdown: {len(urls)} URL(s), {len(files)} file(s)") + return urls, files + + def _filter_and_log_cached_assets( + self, urls: list[str], files: list[str] + ) -> tuple[list[str], list[str]]: + """ + Filter out cached assets and log statistics. + + Args: + urls: List of URL assets. + files: List of file assets. + + Returns: + Tuple of (uncached_urls, uncached_files). 
+ """ + uncached_urls = self._filter_uncached(urls, self.asset_uploader._url_cache) + uncached_files = self._filter_uncached( + files, self.asset_uploader._get_file_cache() + ) + + logger.info( + f"Assets to upload: {len(uncached_urls)} URL(s), {len(uncached_files)} file(s) " + f"(skipped {len(urls) - len(uncached_urls)} cached URL(s), " + f"{len(files) - len(uncached_files)} cached file(s))" + ) + + return uncached_urls, uncached_files + + def _perform_uploads( + self, + uncached_urls: list[str], + uncached_files: list[str], + asset_completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """ + Execute asset uploads with progress tracking. + + Args: + uncached_urls: URLs to upload. + uncached_files: Files to upload. + asset_completion_callback: Callback for completed assets. + + Returns: + List of failed uploads. + """ + failed_uploads: list[FailedUpload[str]] = [] + total = len(uncached_urls) + len(uncached_files) + + with tqdm( + total=total, + desc="Step 1/2: Uploading assets", + position=0, + disable=rapidata_config.logging.silent_mode, + leave=True, + ) as pbar: + # Upload URLs + if uncached_urls: + url_failures = self._upload_urls_with_progress( + uncached_urls, pbar, asset_completion_callback + ) + failed_uploads.extend(url_failures) + else: + logger.debug("No uncached URLs to upload") + + # Upload files + if uncached_files: + file_failures = self._upload_files_with_progress( + uncached_files, pbar, asset_completion_callback + ) + failed_uploads.extend(file_failures) + else: + logger.debug("No uncached files to upload") + + return failed_uploads + + def _upload_urls_with_progress( + self, + urls: list[str], + pbar: tqdm, + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """Upload URLs with progress bar updates.""" + logger.debug(f"Batch uploading {len(urls)} URL(s)") + + def update_progress(n: int) -> None: + pbar.update(n) + + return self.batch_uploader.batch_upload_urls( + urls, + progress_callback=update_progress, + completion_callback=completion_callback, + ) + + def _upload_files_with_progress( + self, + files: list[str], + pbar: tqdm, + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """Upload files with progress bar updates.""" + logger.debug(f"Parallel uploading {len(files)} file(s)") + + def update_progress() -> None: + pbar.update(1) + + return self._upload_files_parallel( + files, + progress_callback=update_progress, + completion_callback=completion_callback, + ) + + def _log_upload_results(self, failed_uploads: list[FailedUpload[str]]) -> None: + """Log the results of asset uploads.""" + if failed_uploads: + logger.warning(f"Step 1/2: {len(failed_uploads)} asset(s) failed to upload") + else: + logger.info("Step 1/2: All assets uploaded successfully") + + def _filter_uncached( + self, assets: list[str], cache: SingleFlightCache + ) -> list[str]: + """Filter out assets that are already cached.""" + uncached = [] + for asset in assets: + try: + # Try to get cache key using centralized methods + if re.match(r"^https?://", asset): + cache_key = self.asset_uploader.get_url_cache_key(asset) + else: + cache_key = self.asset_uploader.get_file_cache_key(asset) + + # Check if in cache + if cache_key not in cache.get_storage(): + uncached.append(asset) + except Exception as e: + # If cache check fails, include in upload list + logger.warning(f"Cache check failed for {asset}: {e}") + uncached.append(asset) + + return uncached + + def _upload_files_parallel( + self, 
+ files: list[str], + progress_callback: Callable[[], None] | None = None, + completion_callback: Callable[[list[str]], None] | None = None, + ) -> list[FailedUpload[str]]: + """ + Upload files in parallel using ThreadPoolExecutor. + + Args: + files: List of file paths to upload. + progress_callback: Optional callback to report progress (called once per completed file). + completion_callback: Optional callback to notify when files complete (called with list of successful files). + + Returns: + List of FailedUpload instances for any files that failed. + """ + failed_uploads: list[FailedUpload[str]] = [] + + def upload_single_file(file_path: str) -> FailedUpload[str] | None: + """Upload a single file and return FailedUpload if it fails.""" + try: + self.asset_uploader.upload_asset(file_path) + return None + except Exception as e: + logger.warning(f"Failed to upload file {file_path}: {e}") + return FailedUpload.from_exception(file_path, e) + + with ThreadPoolExecutor( + max_workers=rapidata_config.upload.maxWorkers + ) as executor: + futures = { + executor.submit(upload_single_file, file_path): file_path + for file_path in files + } + + for future in as_completed(futures): + file_path = futures[future] + result = future.result() + if result is not None: + failed_uploads.append(result) + else: + # File uploaded successfully, notify callback + if completion_callback: + completion_callback([file_path]) + + if progress_callback: + progress_callback() + + return failed_uploads diff --git a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py index 5be916f79..d24e35d90 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py @@ -2,6 +2,7 @@ import re import os +import threading from typing import TYPE_CHECKING from rapidata.service.openapi_service import OpenAPIService @@ -14,20 +15,51 @@ class AssetUploader: - _file_cache: SingleFlightCache = SingleFlightCache( - "File cache", - storage=FanoutCache( - rapidata_config.upload.cacheLocation, - shards=rapidata_config.upload.cacheShards, - timeout=rapidata_config.upload.cacheTimeout, - ), - ) - _url_cache: SingleFlightCache = SingleFlightCache("URL cache") - - def __init__(self, openapi_service: OpenAPIService): + # Class-level caches shared across all instances + # URL cache: Always in-memory (URLs are lightweight, no benefit to disk caching) + # File cache: Lazily initialized based on cacheToDisk config + _url_cache: SingleFlightCache = SingleFlightCache("URL cache", storage={}) + _file_cache: SingleFlightCache | None = None + _file_cache_lock: threading.Lock = threading.Lock() + + @classmethod + def _get_file_cache(cls) -> SingleFlightCache: + """ + Get or create the file cache based on current config. + + Uses lazy initialization to respect cacheToDisk setting at runtime. + Thread-safe with double-checked locking pattern. + + Returns: + Configured file cache (disk or memory based on cacheToDisk). 
+ """ + if cls._file_cache is not None: + return cls._file_cache + + with cls._file_cache_lock: + # Double-check after acquiring lock + if cls._file_cache is not None: + return cls._file_cache + + # Create cache storage based on current config + if rapidata_config.upload.cacheToDisk: + storage: dict[str, str] | FanoutCache = FanoutCache( + rapidata_config.upload.cacheLocation, + shards=rapidata_config.upload.cacheShards, + timeout=rapidata_config.upload.cacheTimeout, + ) + logger.debug("Initialized file cache with disk storage") + else: + storage = {} + logger.debug("Initialized file cache with in-memory storage") + + cls._file_cache = SingleFlightCache("File cache", storage=storage) + return cls._file_cache + + def __init__(self, openapi_service: OpenAPIService) -> None: self.openapi_service = openapi_service - def _get_file_cache_key(self, asset: str) -> str: + def get_file_cache_key(self, asset: str) -> str: """Generate cache key for a file, including environment.""" env = self.openapi_service.environment if not os.path.exists(asset): @@ -36,13 +68,18 @@ def _get_file_cache_key(self, asset: str) -> str: stat = os.stat(asset) return f"{env}@{asset}:{stat.st_size}:{stat.st_mtime_ns}" - def _get_url_cache_key(self, url: str) -> str: + def get_url_cache_key(self, url: str) -> str: """Generate cache key for a URL, including environment.""" env = self.openapi_service.environment return f"{env}@{url}" def _upload_url_asset(self, url: str) -> str: - """Upload a URL asset, with optional caching.""" + """ + Upload a URL asset with caching. + + URLs are always cached in-memory (lightweight, no disk I/O overhead). + Caching is required for the two-step upload flow and cannot be disabled. + """ def upload_url() -> str: response = self.openapi_service.asset_api.asset_url_post(url=url) @@ -51,17 +88,15 @@ def upload_url() -> str: ) return response.file_name - if not rapidata_config.upload.cacheUploads: - return upload_url() - - return self._url_cache.get_or_fetch( - self._get_url_cache_key(url), - upload_url, - should_cache=rapidata_config.upload.cacheUploads, - ) + return self._url_cache.get_or_fetch(self.get_url_cache_key(url), upload_url) def _upload_file_asset(self, file_path: str) -> str: - """Upload a local file asset, with optional caching.""" + """ + Upload a local file asset with caching. + + Caching is always enabled as it's required for the two-step upload flow. + Use cacheToDisk config to control whether cache is stored to disk or memory. 
+ """ def upload_file() -> str: response = self.openapi_service.asset_api.asset_file_post(file=file_path) @@ -72,13 +107,8 @@ def upload_file() -> str: ) return response.file_name - if not rapidata_config.upload.cacheUploads: - return upload_file() - - return self._file_cache.get_or_fetch( - self._get_file_cache_key(file_path), - upload_file, - should_cache=rapidata_config.upload.cacheUploads, + return self._get_file_cache().get_or_fetch( + self.get_file_cache_key(file_path), upload_file ) def upload_asset(self, asset: str) -> str: @@ -91,8 +121,9 @@ def upload_asset(self, asset: str) -> str: return self._upload_file_asset(asset) - def clear_cache(self): - self._file_cache.clear() + def clear_cache(self) -> None: + """Clear both URL and file caches.""" + self._get_file_cache().clear() self._url_cache.clear() logger.info("Upload cache cleared") diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py new file mode 100644 index 000000000..f5ca14c30 --- /dev/null +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import time +import threading +from typing import Callable, TYPE_CHECKING + +from rapidata.rapidata_client.config import logger, rapidata_config +from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader +from rapidata.api_client.models.batch_upload_status import BatchUploadStatus +from rapidata.api_client.models.batch_upload_url_status import BatchUploadUrlStatus +from rapidata.api_client.models.create_batch_upload_endpoint_input import ( + CreateBatchUploadEndpointInput, +) + +if TYPE_CHECKING: + from rapidata.service.openapi_service import OpenAPIService + from rapidata.api_client.models.get_batch_upload_status_endpoint_output import ( + GetBatchUploadStatusEndpointOutput, + ) + + +class BatchAssetUploader: + """ + Handles batch uploading of URL assets using the batch upload API. + + This class submits URLs in batches, polls for completion, and updates + the shared URL cache with successful uploads. + """ + + def __init__(self, openapi_service: OpenAPIService) -> None: + self.openapi_service = openapi_service + self.asset_uploader = AssetUploader(openapi_service) + self.url_cache = AssetUploader._url_cache + self._interrupted = False + + def batch_upload_urls( + self, + urls: list[str], + progress_callback: Callable[[int], None] | None = None, + completion_callback: Callable[[list[str]], None] | None = None, + ) -> list[FailedUpload[str]]: + """ + Upload URLs in batches. Returns list of failed uploads. + Successful uploads are cached automatically. + + Batches are submitted concurrently with polling - polling starts as soon + as the first batch is submitted, allowing progress to be visible immediately. + + Args: + urls: List of URLs to upload. + progress_callback: Optional callback to report progress (called with number of newly completed items). + completion_callback: Optional callback to notify when URLs complete (called with list of successful URLs). + + Returns: + List of FailedUpload instances for any URLs that failed. 
+ """ + if not urls: + return [] + + # Split into batches + batches = self._split_into_batches(urls) + + # Thread-safe collections for concurrent submission and polling + batch_ids_lock = threading.Lock() + batch_ids: list[str] = [] + batch_to_urls: dict[str, list[str]] = {} + submission_complete = threading.Event() + + try: + # Submit batches in background thread + def submit_batches_background(): + """Submit all batches and signal completion.""" + for batch_idx, batch in enumerate(batches): + # Check if interrupted before submitting next batch + if self._interrupted: + logger.debug("Batch submission stopped due to interruption") + break + + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_post( + create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( + urls=batch + ) + ) + batch_id = result.batch_upload_id + + # Add to shared collections (thread-safe) + with batch_ids_lock: + batch_ids.append(batch_id) + batch_to_urls[batch_id] = batch + + logger.debug( + f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" + ) + except Exception as e: + logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") + + # Signal that all batches have been submitted + submission_complete.set() + with batch_ids_lock: + logger.info( + f"Successfully submitted {len(batch_ids)}/{len(batches)} batches" + ) + + # Start background submission + submission_thread = threading.Thread( + target=submit_batches_background, daemon=True + ) + submission_thread.start() + + # Wait for at least one batch to be submitted before starting poll + while len(batch_ids) == 0 and not submission_complete.is_set(): + time.sleep(0.5) + + if len(batch_ids) == 0: + logger.error("No batches were successfully submitted") + return self._create_submission_failures(urls) + + # Poll until complete (will handle dynamically growing batch list) + return self._poll_until_complete( + batch_ids, + batch_to_urls, + batch_ids_lock, + submission_complete, + progress_callback, + completion_callback, + ) + + except KeyboardInterrupt: + logger.warning("Batch upload interrupted by user (Ctrl+C)") + self._interrupted = True + raise # Re-raise to propagate interruption + + finally: + # Cleanup: abort batches if interrupted + if self._interrupted: + self._abort_batches(batch_ids, batch_ids_lock) + + def _split_into_batches(self, urls: list[str]) -> list[list[str]]: + """Split URLs into batches of configured size.""" + batch_size = rapidata_config.upload.batchSize + batches = [urls[i : i + batch_size] for i in range(0, len(urls), batch_size)] + logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") + return batches + + def _poll_until_complete( + self, + batch_ids: list[str], + batch_to_urls: dict[str, list[str]], + batch_ids_lock: threading.Lock, + submission_complete: threading.Event, + progress_callback: Callable[[int], None] | None, + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """ + Poll batches until all complete. Process batches incrementally as they complete. + + Supports concurrent batch submission - will poll currently submitted batches + and continue until all batches are submitted and completed. + + Args: + batch_ids: Shared list of batch IDs (grows as batches are submitted). + batch_to_urls: Shared mapping from batch_id to list of URLs in that batch. + batch_ids_lock: Lock protecting batch_ids and batch_to_urls. + submission_complete: Event signaling all batches have been submitted. + progress_callback: Optional callback to report progress. 
+ completion_callback: Optional callback to notify when URLs complete. + + Returns: + List of FailedUpload instances for any URLs that failed. + """ + poll_interval = rapidata_config.upload.batchPollInterval + + last_completed = 0 + start_time = time.time() + processed_batches: set[str] = set() + all_failures: list[FailedUpload[str]] = [] + + while True: + # Check for interruption at start of each iteration + if self._interrupted: + logger.debug("Polling stopped due to interruption") + break + + # Get current batch IDs (thread-safe) + with batch_ids_lock: + current_batch_ids = batch_ids.copy() + total_batches_submitted = len(current_batch_ids) + + if not current_batch_ids: + # No batches yet, wait a bit + time.sleep(poll_interval) + continue + + logger.debug( + f"Polling {total_batches_submitted} submitted batch(es) for completion" + ) + try: + status = ( + self.openapi_service.batch_upload_api.asset_batch_upload_status_get( + batch_upload_ids=current_batch_ids + ) + ) + + # Process newly completed batches + for batch_id in status.completed_batches: + if batch_id not in processed_batches: + successful_urls, failures = self._process_single_batch( + batch_id, batch_to_urls + ) + processed_batches.add(batch_id) + all_failures.extend(failures) + + # Notify callback with completed URLs + if completion_callback and successful_urls: + completion_callback(successful_urls) + + # Update progress + self._update_progress(status, last_completed, progress_callback) + last_completed = status.completed_count + status.failed_count + + # Check if we're done: + # 1. All batches have been submitted + # 2. All submitted batches have been processed + if submission_complete.is_set() and len(processed_batches) == len( + current_batch_ids + ): + elapsed = time.time() - start_time + logger.info( + f"All batches completed in {elapsed:.1f}s: " + f"{status.completed_count} succeeded, {status.failed_count} failed" + ) + return all_failures + + # Wait before next poll + time.sleep(poll_interval) + + except Exception as e: + logger.error(f"Error polling batch status: {e}") + time.sleep(poll_interval) + + # Return failures collected so far (reached via break on interruption) + return all_failures + + def _update_progress( + self, + status: GetBatchUploadStatusEndpointOutput, + last_completed: int, + progress_callback: Callable[[int], None] | None, + ) -> None: + """Update progress callback if provided.""" + if progress_callback: + new_completed = status.completed_count + status.failed_count + if new_completed > last_completed: + progress_callback(new_completed - last_completed) + + def _process_single_batch( + self, batch_id: str, batch_to_urls: dict[str, list[str]] + ) -> tuple[list[str], list[FailedUpload[str]]]: + """ + Fetch and cache results for a single batch. + + Args: + batch_id: The batch ID to process. + batch_to_urls: Mapping from batch_id to list of URLs in that batch. + + Returns: + Tuple of (successful_urls, failed_uploads). 
+ """ + successful_urls: list[str] = [] + failed_uploads: list[FailedUpload[str]] = [] + + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_get( + batch_upload_id=batch_id + ) + + # Process each URL in the batch result + for item in result.items: + if item.status == BatchUploadUrlStatus.COMPLETED: + # Cache successful upload using proper API + if item.file_name is not None: + cache_key = self.asset_uploader.get_url_cache_key(item.url) + self.url_cache.set(cache_key, item.file_name) + successful_urls.append(item.url) + logger.debug( + f"Cached successful upload: {item.url} -> {item.file_name}" + ) + else: + logger.warning( + f"Batch upload completed but file_name is None for URL: {item.url}" + ) + failed_uploads.append( + FailedUpload( + item=item.url, + error_type="BatchUploadFailed", + error_message="Upload completed but file_name is None", + ) + ) + else: + # Track failure + failed_uploads.append( + FailedUpload( + item=item.url, + error_type="BatchUploadFailed", + error_message=item.error_message + or "Unknown batch upload error", + ) + ) + logger.warning( + f"URL failed in batch: {item.url} - {item.error_message}" + ) + + except Exception as e: + logger.error(f"Failed to fetch results for batch {batch_id}: {e}") + # Create individual FailedUpload for each URL in the failed batch + # Use from_exception to extract proper error reason from RapidataError + if batch_id in batch_to_urls: + for url in batch_to_urls[batch_id]: + failed_uploads.append(FailedUpload.from_exception(url, e)) + else: + # Fallback if batch_id not found in mapping + failed_uploads.append( + FailedUpload.from_exception(f"batch_{batch_id}", e) + ) + + if successful_urls: + logger.debug( + f"Batch {batch_id}: {len(successful_urls)} succeeded, {len(failed_uploads)} failed" + ) + + return successful_urls, failed_uploads + + def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str]]: + """Create FailedUpload instances for all URLs when submission fails.""" + return [ + FailedUpload( + item=url, + error_type="BatchSubmissionFailed", + error_message="Failed to submit any batches", + ) + for url in urls + ] + + def _abort_batches( + self, + batch_ids: list[str], + batch_ids_lock: threading.Lock, + ) -> None: + """ + Abort all submitted batches by calling the abort endpoint. + + This method is called during cleanup when the upload process is interrupted. + It attempts to abort all batches that were successfully submitted. + + Args: + batch_ids: Shared list of batch IDs (thread-safe access required). + batch_ids_lock: Lock protecting batch_ids list. + """ + # Get snapshot of current batch IDs + with batch_ids_lock: + batches_to_abort = batch_ids.copy() + + if not batches_to_abort: + logger.info("No batches to abort") + return + + logger.info( + f"Aborting {len(batches_to_abort)} batch(es) due to interruption..." 
+ ) + + abort_successes = 0 + abort_failures = 0 + + for batch_id in batches_to_abort: + try: + self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_abort_post( + batch_upload_id=batch_id + ) + abort_successes += 1 + logger.debug(f"Successfully aborted batch: {batch_id}") + except Exception as e: + abort_failures += 1 + logger.warning(f"Failed to abort batch {batch_id}: {e}") + + logger.info( + f"Batch abort completed: {abort_successes} succeeded, {abort_failures} failed" + ) diff --git a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py index c7ae48cd5..e3f3b1c8e 100644 --- a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py +++ b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py @@ -29,6 +29,10 @@ def set_storage(self, storage: dict[str, str] | FanoutCache) -> None: except Exception: pass + def get_storage(self) -> dict[str, str] | FanoutCache: + """Get the cache storage.""" + return self._storage + def get_or_fetch( self, key: str, @@ -77,6 +81,10 @@ def get_or_fetch( with self._lock: self._in_flight.pop(key, None) + def set(self, key: str, value: str) -> None: + """Set a value in the cache.""" + self._storage[key] = value + def clear(self) -> None: """Clear the cache.""" self._storage.clear() diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index fd1b907f3..c6cfb9ef7 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -1,22 +1,74 @@ +""" +RapidataDataset module for managing datapoint uploads with incremental asset processing. + +Threading Model: +--------------- +This module uses concurrent processing for both asset uploads and datapoint creation: + +1. Asset Upload Phase (Step 1/2): + - URLs are uploaded in batches with polling for completion + - Files are uploaded in parallel using ThreadPoolExecutor + - Completion callbacks are invoked from worker threads + +2. Datapoint Creation Phase (Step 2/2): + - Datapoints are created incrementally as their required assets complete + - Uses ThreadPoolExecutor with max_workers=rapidata_config.upload.maxWorkers + - Callbacks from asset upload trigger datapoint creation submissions + +Thread-Safety: +------------- +- `lock` (threading.Lock): Protects all shared state during incremental processing + - `datapoint_pending_count`: Maps datapoint index to remaining asset count + - `asset_to_datapoints`: Maps asset to set of datapoint indices waiting for it + - `creation_futures`: List of (idx, Future) tuples for datapoint creation tasks + +Lock Acquisition Order: +---------------------- +1. `on_assets_complete` callback acquires lock to update shared state +2. Lock is released before submitting datapoint creation tasks to avoid blocking +3. 
Lock is re-acquired briefly to append futures to creation_futures list + +The callback-based design ensures: +- Assets can complete incrementally (batch-by-batch, file-by-file) +- Datapoints are created as soon as all their assets are ready +- No deadlocks occur between asset completion and datapoint submission +""" + +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor, Future +from typing import Callable + +from tqdm.auto import tqdm + from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.datapoints._datapoint_uploader import DatapointUploader -from rapidata.rapidata_client.utils.threaded_uploader import ThreadedUploader +from rapidata.rapidata_client.datapoints._asset_upload_orchestrator import ( + AssetUploadOrchestrator, + extract_assets_from_datapoint, +) from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.config import rapidata_config, logger class RapidataDataset: - def __init__(self, dataset_id: str, openapi_service: OpenAPIService): + def __init__(self, dataset_id: str, openapi_service: OpenAPIService) -> None: self.id = dataset_id self.openapi_service = openapi_service self.datapoint_uploader = DatapointUploader(openapi_service) + self.asset_orchestrator = AssetUploadOrchestrator(openapi_service) def add_datapoints( self, datapoints: list[Datapoint], ) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: """ - Process uploads in chunks with a ThreadPoolExecutor. + Upload datapoints with incremental creation: + - Start uploading all assets (URLs in batches + files in parallel) + - As assets complete, check which datapoints are ready and create them + - Continue until all uploads and datapoint creation complete Args: datapoints: List of datapoints to upload @@ -24,21 +76,317 @@ def add_datapoints( Returns: tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: Lists of successful uploads and failed uploads with error details """ + if not datapoints: + return [], [] - def upload_single_datapoint(datapoint: Datapoint, index: int) -> None: - self.datapoint_uploader.upload_datapoint( - dataset_id=self.id, - datapoint=datapoint, - index=index, - ) + # 1. Build asset-to-datapoint mappings + asset_to_datapoints, datapoint_pending_count = ( + self._build_asset_to_datapoint_mapping(datapoints) + ) - uploader: ThreadedUploader[Datapoint] = ThreadedUploader( - upload_fn=upload_single_datapoint, - description="Uploading datapoints", + # 2. Set up shared state for incremental creation + creation_futures: list[tuple[int, Future]] = [] + lock = threading.Lock() + executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers) + + # 3. Execute uploads and incremental datapoint creation + self._execute_incremental_creation( + datapoints, + asset_to_datapoints, + datapoint_pending_count, + creation_futures, + lock, + executor, + ) + + # 4. Collect and return results + return self._collect_and_return_results( + datapoints, creation_futures, datapoint_pending_count, lock ) - successful_uploads, failed_uploads = uploader.upload(datapoints) + def _build_asset_to_datapoint_mapping( + self, datapoints: list[Datapoint] + ) -> tuple[dict[str, set[int]], dict[int, int]]: + """ + Build efficient reverse mapping: asset -> datapoint indices that need it. + This allows O(1) lookup instead of checking all pending datapoints. 
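+        For example (illustrative values): if datapoints[0] and datapoints[1] both reference
+        "https://example.com/a.png" and datapoints[1] additionally has a local context image
+        "context.png", the result is asset_to_datapoints == {"https://example.com/a.png": {0, 1},
+        "context.png": {1}} and datapoint_pending_count == {0: 1, 1: 2}.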
+ + Note on using indices: We use integer indices instead of datapoint objects because: + - Indices are hashable (required for dict keys and sets) + - Lightweight and fast to compare + - Provides stable references to datapoints in the list + - The datapoint objects remain in a single location (the datapoints list) + + Args: + datapoints: List of datapoints to process. Indices into this list are used as identifiers. + + Returns: + Tuple of (asset_to_datapoints, datapoint_pending_count): + - asset_to_datapoints: Maps each asset to set of datapoint indices (positions in datapoints list) that need it + - datapoint_pending_count: Maps each datapoint index to count of remaining assets it needs + """ + # Using indices instead of datapoints directly for hashability and performance + asset_to_datapoints: dict[str, set[int]] = {} + datapoint_pending_count: dict[int, int] = {} + + for idx, dp in enumerate(datapoints): + # Extract all assets for this datapoint using shared utility + assets = extract_assets_from_datapoint(dp) + + # Track how many assets this datapoint needs + datapoint_pending_count[idx] = len(assets) + # Build reverse mapping + for asset in assets: + if asset not in asset_to_datapoints: + asset_to_datapoints[asset] = set() + asset_to_datapoints[asset].add(idx) + + logger.debug(f"Mapped {len(datapoints)} datapoints to their required assets") + return asset_to_datapoints, datapoint_pending_count + + def _execute_incremental_creation( + self, + datapoints: list[Datapoint], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + ) -> None: + """ + Execute asset uploads and incremental datapoint creation. + + Args: + datapoints: List of datapoints being processed. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint. + creation_futures: List to store creation futures. + lock: Lock protecting shared state. + executor: Thread pool executor for datapoint creation. 
+ """ + # Create progress bar for datapoint creation + datapoint_pbar = tqdm( + total=len(datapoints), + desc="Step 2/2: Creating datapoints", + position=1, + disable=rapidata_config.logging.silent_mode, + leave=True, + ) + + try: + # Create callback that submits datapoints for creation + on_assets_complete = self._create_asset_completion_callback( + datapoints, + asset_to_datapoints, + datapoint_pending_count, + creation_futures, + lock, + executor, + datapoint_pbar, + ) + + # Extract all unique assets from the mapping + all_assets = set(asset_to_datapoints.keys()) + + # Start uploads (blocking, but triggers callbacks as assets complete) + logger.info("Starting incremental datapoint creation") + asset_failures = self.asset_orchestrator.upload_all_assets( + all_assets, asset_completion_callback=on_assets_complete + ) + + if asset_failures: + logger.warning( + f"{len(asset_failures)} asset(s) failed to upload, affected datapoints will be marked as failed" + ) + + # Wait for all datapoint creation to complete + executor.shutdown(wait=True) + logger.debug("All datapoint creation tasks completed") + finally: + # Always close progress bar, even on exception + datapoint_pbar.close() + + def _create_asset_completion_callback( + self, + datapoints: list[Datapoint], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + datapoint_pbar: tqdm, + ) -> Callable[[list[str]], None]: + """ + Create callback function that handles asset completion. + + THREAD-SAFETY: The returned callback is invoked from worker threads during asset upload. + All access to shared state is protected by the lock. + + Args: + datapoints: List of datapoints being processed. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint. + creation_futures: List to store creation futures. + lock: Lock protecting shared state. + executor: Thread pool executor for datapoint creation. + datapoint_pbar: Progress bar for datapoint creation. + + Returns: + Callback function to be invoked when assets complete. + """ + + def on_assets_complete(assets: list[str]) -> None: + """Called when a batch of assets completes uploading.""" + ready_datapoint_indices = self._find_ready_datapoints( + assets, asset_to_datapoints, datapoint_pending_count, lock + ) + + # Submit ready datapoints for creation (outside lock to avoid blocking) + self._submit_datapoints_for_creation( + ready_datapoint_indices, + datapoints, + creation_futures, + lock, + executor, + datapoint_pbar, + ) + + return on_assets_complete + + def _find_ready_datapoints( + self, + assets: list[str], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + lock: threading.Lock, + ) -> list[int]: + """ + Find datapoints that are ready for creation after asset completion. + + Returns indices into the original datapoints list for datapoints whose + assets are now all complete. + + Args: + assets: List of completed assets. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint index. + lock: Lock protecting shared state. + + Returns: + List of datapoint indices (positions in the datapoints list) ready for creation. 
+ """ + ready_datapoint_indices = [] + + with lock: + # For each completed asset, find datapoints that need it + for asset in assets: + if asset in asset_to_datapoints: + # Get all datapoint indices waiting for this asset + for datapoint_idx in list(asset_to_datapoints[asset]): + if datapoint_idx in datapoint_pending_count: + # Decrement the remaining asset count for this datapoint + datapoint_pending_count[datapoint_idx] -= 1 + + # If all assets are now ready, mark this datapoint for creation + if datapoint_pending_count[datapoint_idx] == 0: + ready_datapoint_indices.append(datapoint_idx) + del datapoint_pending_count[datapoint_idx] + + # Remove this datapoint from this asset's waiting list + asset_to_datapoints[asset].discard(datapoint_idx) + + return ready_datapoint_indices + + def _submit_datapoints_for_creation( + self, + ready_datapoint_indices: list[int], + datapoints: list[Datapoint], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + datapoint_pbar: tqdm, + ) -> None: + """ + Submit ready datapoints for creation. + + Args: + ready_datapoint_indices: Indices (positions in datapoints list) of datapoints ready for creation. + datapoints: List of all datapoints (indexed by ready_datapoint_indices). + creation_futures: List to store creation futures. + lock: Lock protecting creation_futures. + executor: Thread pool executor for datapoint creation. + datapoint_pbar: Progress bar for datapoint creation. + """ + for datapoint_idx in ready_datapoint_indices: + + def upload_and_update(dp_idx): + """Upload datapoint and update progress bar when done.""" + try: + self.datapoint_uploader.upload_datapoint( + dataset_id=self.id, + datapoint=datapoints[dp_idx], + index=dp_idx, + ) + finally: + datapoint_pbar.update(1) + + future = executor.submit(upload_and_update, datapoint_idx) + with lock: + creation_futures.append((datapoint_idx, future)) + + if ready_datapoint_indices: + logger.debug( + f"Asset batch completed, {len(ready_datapoint_indices)} datapoints now ready for creation" + ) + + def _collect_and_return_results( + self, + datapoints: list[Datapoint], + creation_futures: list[tuple[int, Future]], + datapoint_pending_count: dict[int, int], + lock: threading.Lock, + ) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: + """ + Collect results from datapoint creation tasks. + + Args: + datapoints: List of all datapoints. + creation_futures: List of creation futures. + datapoint_pending_count: Datapoints whose assets failed. + lock: Lock protecting datapoint_pending_count. + + Returns: + Tuple of (successful_uploads, failed_uploads). 
+ """ + successful_uploads: list[Datapoint] = [] + failed_uploads: list[FailedUpload[Datapoint]] = [] + + # Collect results from creation tasks + for idx, future in creation_futures: + try: + future.result() # Raises exception if failed + successful_uploads.append(datapoints[idx]) + except Exception as e: + logger.warning(f"Failed to create datapoint {idx}: {e}") + # Use from_exception to extract proper error reason from RapidataError + failed_uploads.append(FailedUpload.from_exception(datapoints[idx], e)) + + # Handle datapoints whose assets failed to upload + with lock: + for idx in datapoint_pending_count: + logger.warning(f"Datapoint {idx} assets failed to upload") + failed_uploads.append( + FailedUpload( + item=datapoints[idx], + error_type="AssetUploadFailed", + error_message="One or more required assets failed to upload", + ) + ) + + logger.info( + f"Datapoint creation complete: {len(successful_uploads)} succeeded, {len(failed_uploads)} failed" + ) return successful_uploads, failed_uploads def __str__(self) -> str: diff --git a/src/rapidata/rapidata_client/exceptions/__init__.py b/src/rapidata/rapidata_client/exceptions/__init__.py index 082165944..86ca12215 100644 --- a/src/rapidata/rapidata_client/exceptions/__init__.py +++ b/src/rapidata/rapidata_client/exceptions/__init__.py @@ -1,5 +1,6 @@ +from .asset_upload_exception import AssetUploadException from .failed_upload import FailedUpload from .failed_upload_exception import FailedUploadException from .rapidata_error import RapidataError -__all__ = ["FailedUpload", "FailedUploadException", "RapidataError"] +__all__ = ["AssetUploadException", "FailedUpload", "FailedUploadException", "RapidataError"] diff --git a/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py b/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py new file mode 100644 index 000000000..077eb25ee --- /dev/null +++ b/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload + + +class AssetUploadException(Exception): + """ + Exception raised when asset uploads fail in Step 1/2 of the two-step upload process. + + This exception contains details about which assets failed to upload, + allowing users to decide how to proceed (retry, skip, or abort). + + Attributes: + failed_uploads: List of FailedUpload instances containing the failed assets and error details. + """ + + def __init__(self, failed_uploads: list[FailedUpload[str]]): + self.failed_uploads = failed_uploads + message = ( + f"Failed to upload {len(failed_uploads)} asset(s) in Step 1/2. " + f"See failed_uploads attribute for details." + ) + super().__init__(message) + + def __str__(self) -> str: + error_summary = "\n".join( + f" - {fu.item}: {fu.error_message}" for fu in self.failed_uploads[:5] + ) + if len(self.failed_uploads) > 5: + error_summary += f"\n ... 
and {len(self.failed_uploads) - 5} more" + return f"{super().__str__()}\n{error_summary}" diff --git a/src/rapidata/rapidata_client/exceptions/failed_upload.py b/src/rapidata/rapidata_client/exceptions/failed_upload.py index 1a35e65b5..29b7ff910 100644 --- a/src/rapidata/rapidata_client/exceptions/failed_upload.py +++ b/src/rapidata/rapidata_client/exceptions/failed_upload.py @@ -1,6 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TypeVar, Generic, TYPE_CHECKING +from dataclasses import dataclass, field +from datetime import datetime +from typing import TypeVar, Generic, TYPE_CHECKING, Optional if TYPE_CHECKING: from rapidata.rapidata_client.exceptions.rapidata_error import RapidataError @@ -17,11 +18,15 @@ class FailedUpload(Generic[T]): item: The item that failed to upload. error_message: The error message describing the failure reason. error_type: The type of the exception (e.g., "RapidataError"). + timestamp: Optional timestamp when the failure occurred. + exception: Optional original exception for richer error context. """ item: T error_message: str error_type: str + timestamp: Optional[datetime] = field(default_factory=datetime.now) + exception: Optional[Exception] = None @classmethod def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T]: @@ -43,6 +48,7 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T] item=item, error_message="Unknown error", error_type="Unknown", + exception=None, ) from rapidata.rapidata_client.exceptions.rapidata_error import RapidataError @@ -58,7 +64,26 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T] item=item, error_message=error_message, error_type=error_type, + exception=exception, ) + def format_error_details(self) -> str: + """ + Format error details for logging or display. + + Returns: + Formatted string with all error details including timestamp. 
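+            Example output (illustrative):
+                Item: https://example.com/a.png
+                Error Type: RapidataError
+                Error Message: Asset could not be fetched
+                Timestamp: 2025-01-01T12:00:00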
+ """ + details = [ + f"Item: {self.item}", + f"Error Type: {self.error_type}", + f"Error Message: {self.error_message}", + ] + + if self.timestamp: + details.append(f"Timestamp: {self.timestamp.isoformat()}") + + return "\n".join(details) + def __str__(self) -> str: return f"{self.item}" diff --git a/src/rapidata/rapidata_client/job/rapidata_job.py b/src/rapidata/rapidata_client/job/rapidata_job.py index fb57247d4..632819124 100644 --- a/src/rapidata/rapidata_client/job/rapidata_job.py +++ b/src/rapidata/rapidata_client/job/rapidata_job.py @@ -7,7 +7,7 @@ from time import sleep from typing import Callable, TypeVar, TYPE_CHECKING from colorama import Fore -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.config import ( diff --git a/src/rapidata/rapidata_client/order/rapidata_order.py b/src/rapidata/rapidata_client/order/rapidata_order.py index 60a00d473..4f272df91 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order.py +++ b/src/rapidata/rapidata_client/order/rapidata_order.py @@ -9,7 +9,7 @@ from typing import cast, Callable, TypeVar, TYPE_CHECKING from colorama import Fore from datetime import datetime -from tqdm import tqdm +from tqdm.auto import tqdm # Local/application imports from rapidata.service.openapi_service import OpenAPIService @@ -336,7 +336,9 @@ def get_results(self, preliminary_results: bool = False) -> RapidataResults: with tracer.start_as_current_span("RapidataOrder.get_results"): from rapidata.api_client.models.order_state import OrderState from rapidata.api_client.exceptions import ApiException - from rapidata.rapidata_client.results.rapidata_results import RapidataResults + from rapidata.rapidata_client.results.rapidata_results import ( + RapidataResults, + ) logger.info("Getting results for order '%s'...", self) diff --git a/src/rapidata/rapidata_client/order/rapidata_order_manager.py b/src/rapidata/rapidata_client/order/rapidata_order_manager.py index c80026151..f251e2d2e 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order_manager.py +++ b/src/rapidata/rapidata_client/order/rapidata_order_manager.py @@ -102,7 +102,7 @@ def _create_general_order( "Creating order with parameters: name %s, workflow %s, datapoints %s, responses_per_datapoint %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, selections %s", name, workflow, - datapoints, + len(datapoints), responses_per_datapoint, validation_set_id, confidence_threshold, diff --git a/src/rapidata/rapidata_client/utils/threaded_uploader.py b/src/rapidata/rapidata_client/utils/threaded_uploader.py index b326c6014..a832f9495 100644 --- a/src/rapidata/rapidata_client/utils/threaded_uploader.py +++ b/src/rapidata/rapidata_client/utils/threaded_uploader.py @@ -1,6 +1,6 @@ from typing import TypeVar, Generic, Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from tqdm import tqdm +from tqdm.auto import tqdm import time from rapidata.rapidata_client.config import logger diff --git a/src/rapidata/rapidata_client/validation/validation_set_manager.py b/src/rapidata/rapidata_client/validation/validation_set_manager.py index 4ce03d371..f58bcd5a4 100644 --- a/src/rapidata/rapidata_client/validation/validation_set_manager.py +++ b/src/rapidata/rapidata_client/validation/validation_set_manager.py @@ -30,7 +30,7 @@ rapidata_config, tracer, ) -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.rapidata_client.validation.rapids.rapids import Rapid if TYPE_CHECKING: 
diff --git a/src/rapidata/service/openapi_service.py b/src/rapidata/service/openapi_service.py index 7295eb151..8ad33926e 100644 --- a/src/rapidata/service/openapi_service.py +++ b/src/rapidata/service/openapi_service.py @@ -24,6 +24,7 @@ from rapidata.api_client.api.workflow_api import WorkflowApi from rapidata.api_client.api.participant_api import ParticipantApi from rapidata.api_client.api.audience_api import AudienceApi + from rapidata.api_client.api.batch_upload_api import BatchUploadApi class OpenAPIService: @@ -142,6 +143,12 @@ def asset_api(self) -> AssetApi: return AssetApi(self.api_client) + @property + def batch_upload_api(self) -> BatchUploadApi: + from rapidata.api_client.api.batch_upload_api import BatchUploadApi + + return BatchUploadApi(self.api_client) + @property def dataset_api(self) -> DatasetApi: from rapidata.api_client.api.dataset_api import DatasetApi diff --git a/src/rapidata/types/__init__.py b/src/rapidata/types/__init__.py index b7079e44f..481e03455 100644 --- a/src/rapidata/types/__init__.py +++ b/src/rapidata/types/__init__.py @@ -71,7 +71,6 @@ # Configuration Types -from rapidata.rapidata_client.config.order_config import OrderConfig from rapidata.rapidata_client.config.upload_config import UploadConfig from rapidata.rapidata_client.config.rapidata_config import RapidataConfig from rapidata.rapidata_client.config.logging_config import LoggingConfig @@ -136,7 +135,6 @@ "TranslationBehaviour", "TranslationBehaviourOptions", # Configuration Types - "OrderConfig", "UploadConfig", "RapidataConfig", "LoggingConfig",
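A minimal sketch of tuning the new upload settings from upload_config.py; the values below are illustrative, and because validate_assignment is enabled the batchSize assignment is re-checked by the >= 100 validator.

    from rapidata.rapidata_client.config import rapidata_config

    rapidata_config.upload.cacheToDisk = False      # keep the file-upload cache in memory only
    rapidata_config.upload.batchSize = 500          # URLs per batch-upload request
    rapidata_config.upload.batchPollInterval = 1.0  # seconds between batch-status polls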