From c3075ba53a24d1a4f7b34f58c093b3bf8f88cf0f Mon Sep 17 00:00:00 2001
From: Lino Giger <68745352+LinoGiger@users.noreply.github.com>
Date: Tue, 27 Jan 2026 13:30:53 +0100
Subject: [PATCH 01/21] initial setup

---
 .../rapidata_client/config/upload_config.py   |  36 +++
 .../datapoints/_asset_upload_orchestrator.py  | 192 ++++++++++++++++
 .../datapoints/_batch_asset_uploader.py       | 207 ++++++++++++++++++
 .../dataset/_rapidata_dataset.py              |  20 +-
 .../rapidata_client/exceptions/__init__.py    |   3 +-
 .../exceptions/asset_upload_exception.py      |  31 +++
 src/rapidata/service/openapi_service.py       |   7 +
 7 files changed, 493 insertions(+), 3 deletions(-)
 create mode 100644 src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py
 create mode 100644 src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py
 create mode 100644 src/rapidata/rapidata_client/exceptions/asset_upload_exception.py

diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py
index b2981c390..40c9de7f3 100644
--- a/src/rapidata/rapidata_client/config/upload_config.py
+++ b/src/rapidata/rapidata_client/config/upload_config.py
@@ -18,6 +18,11 @@ class UploadConfig(BaseModel):
         cacheShards (int): Number of cache shards for parallel access. Defaults to 128.
             Higher values improve concurrency but increase file handles. Must be positive.
             This is immutable
+        enableBatchUpload (bool): Enable batch URL uploading (two-step process). Defaults to True.
+        batchSize (int): Number of URLs per batch (10-500). Defaults to 100.
+        batchPollInterval (float): Polling interval in seconds. Defaults to 0.5.
+        batchPollMaxInterval (float): Maximum polling interval. Defaults to 5.0.
+        batchTimeout (float): Batch upload timeout in seconds. Defaults to 300.0.
     """
 
     maxWorkers: int = Field(default=25)
@@ -32,6 +37,26 @@ class UploadConfig(BaseModel):
         default=128,
         frozen=True,
     )
+    enableBatchUpload: bool = Field(
+        default=True,
+        description="Enable batch URL uploading (two-step process)",
+    )
+    batchSize: int = Field(
+        default=100,
+        description="Number of URLs per batch (10-500)",
+    )
+    batchPollInterval: float = Field(
+        default=0.5,
+        description="Polling interval in seconds",
+    )
+    batchPollMaxInterval: float = Field(
+        default=5.0,
+        description="Maximum polling interval",
+    )
+    batchTimeout: float = Field(
+        default=300.0,
+        description="Batch upload timeout in seconds",
+    )
 
     @field_validator("maxWorkers")
     @classmethod
@@ -54,6 +79,17 @@ def validate_cache_shards(cls, v: int) -> int:
             )
         return v
 
+    @field_validator("batchSize")
+    @classmethod
+    def validate_batch_size(cls, v: int) -> int:
+        if v < 10:
+            raise ValueError("batchSize must be at least 10")
+        if v > 500:
+            logger.warning(
+                f"batchSize={v} may cause timeouts. Recommend 50-200."
+            )
+        return v
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._migrate_cache()
diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py
new file mode 100644
index 000000000..5afd5aa2e
--- /dev/null
+++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import TYPE_CHECKING
+
+from tqdm import tqdm
+
+from rapidata.rapidata_client.config import logger, rapidata_config
+from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
+from rapidata.rapidata_client.datapoints._batch_asset_uploader import (
+    BatchAssetUploader,
+)
+from rapidata.rapidata_client.exceptions.asset_upload_exception import (
+    AssetUploadException,
+)
+from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload
+
+if TYPE_CHECKING:
+    from rapidata.rapidata_client.datapoints._datapoint import Datapoint
+    from rapidata.service.openapi_service import OpenAPIService
+
+
+class AssetUploadOrchestrator:
+    """
+    Orchestrates Step 1/2: Upload ALL assets from ALL datapoints.
+
+    This class extracts all unique assets, separates URLs from files,
+    filters cached assets, and uploads uncached assets using batch
+    upload for URLs and parallel upload for files.
+
+    Raises AssetUploadException if any uploads fail.
+    """
+
+    def __init__(self, openapi_service: OpenAPIService):
+        self.asset_uploader = AssetUploader(openapi_service)
+        self.batch_uploader = BatchAssetUploader(openapi_service)
+
+    def upload_all_assets(self, datapoints: list[Datapoint]) -> None:
+        """
+        Step 1/2: Upload ALL assets from ALL datapoints.
+        Throws AssetUploadException if any uploads fail.
+
+        Args:
+            datapoints: List of datapoints to extract assets from.
+
+        Raises:
+            AssetUploadException: If any asset uploads fail.
+        """
+        # 1. Extract all unique assets (deduplicate)
+        all_assets = self._extract_unique_assets(datapoints)
+        logger.info(f"Extracted {len(all_assets)} unique asset(s) from datapoints")
+
+        if not all_assets:
+            logger.debug("No assets to upload")
+            return
+
+        # 2. Separate URLs vs files
+        urls = [a for a in all_assets if re.match(r"^https?://", a)]
+        files = [a for a in all_assets if not re.match(r"^https?://", a)]
+        logger.debug(f"Asset breakdown: {len(urls)} URL(s), {len(files)} file(s)")
+
+        # 3. Filter cached (skip already-uploaded assets)
+        uncached_urls = self._filter_uncached(urls, self.asset_uploader._url_cache)
+        uncached_files = self._filter_uncached(files, self.asset_uploader._file_cache)
+        logger.info(
+            f"Assets to upload: {len(uncached_urls)} URL(s), {len(uncached_files)} file(s) "
+            f"(skipped {len(urls) - len(uncached_urls)} cached URL(s), "
+            f"{len(files) - len(uncached_files)} cached file(s))"
+        )
+
+        total = len(uncached_urls) + len(uncached_files)
+        if total == 0:
+            logger.debug("All assets cached, nothing to upload")
+            return
+
+        # 4. Upload with single progress bar
+        failed_uploads: list[FailedUpload[str]] = []
+        with tqdm(
+            total=total,
+            desc="Step 1/2: Uploading assets",
+            disable=rapidata_config.logging.silent_mode,
+        ) as pbar:
+            # 4a. Batch upload URLs
+            if uncached_urls:
+                logger.debug(f"Batch uploading {len(uncached_urls)} URL(s)")
+                url_failures = self.batch_uploader.batch_upload_urls(
+                    uncached_urls, progress_callback=lambda n: pbar.update(n)
+                )
+                failed_uploads.extend(url_failures)
+            else:
+                logger.debug("No uncached URLs to upload")
+
+            # 4b. Parallel upload files
+            if uncached_files:
+                logger.debug(f"Parallel uploading {len(uncached_files)} file(s)")
+                file_failures = self._upload_files_parallel(
+                    uncached_files, progress_callback=lambda: pbar.update(1)
+                )
+                failed_uploads.extend(file_failures)
+            else:
+                logger.debug("No uncached files to upload")
+
+        # 5. If any failures, throw exception (before Step 2)
+        if failed_uploads:
+            logger.error(
+                f"Asset upload failed for {len(failed_uploads)} asset(s) in Step 1/2"
+            )
+            raise AssetUploadException(failed_uploads)
+
+        logger.info("Step 1/2: All assets uploaded successfully")
+
+    def _extract_unique_assets(self, datapoints: list[Datapoint]) -> set[str]:
+        """Extract all unique assets from all datapoints."""
+        assets: set[str] = set()
+        for dp in datapoints:
+            # Main asset(s)
+            if isinstance(dp.asset, list):
+                assets.update(dp.asset)
+            else:
+                assets.add(dp.asset)
+            # Context asset
+            if dp.media_context:
+                assets.add(dp.media_context)
+        return assets
+
+    def _filter_uncached(
+        self, assets: list[str], cache
+    ) -> list[str]:
+        """Filter out assets that are already cached."""
+        uncached = []
+        for asset in assets:
+            try:
+                # Try to get cache key
+                if re.match(r"^https?://", asset):
+                    cache_key = f"{self.asset_uploader.openapi_service.environment}@{asset}"
+                else:
+                    cache_key = self.asset_uploader._get_file_cache_key(asset)
+
+                # Check if in cache
+                if cache_key not in cache._storage:
+                    uncached.append(asset)
+            except Exception as e:
+                # If cache check fails, include in upload list
+                logger.debug(f"Cache check failed for {asset}: {e}")
+                uncached.append(asset)
+
+        return uncached
+
+    def _upload_files_parallel(
+        self,
+        files: list[str],
+        progress_callback: callable | None = None,
+    ) -> list[FailedUpload[str]]:
+        """
+        Upload files in parallel using ThreadPoolExecutor.
+
+        Args:
+            files: List of file paths to upload.
+            progress_callback: Optional callback to report progress (called once per completed file).
+
+        Returns:
+            List of FailedUpload instances for any files that failed.
+ """ + failed_uploads: list[FailedUpload[str]] = [] + + def upload_single_file(file_path: str) -> FailedUpload[str] | None: + """Upload a single file and return FailedUpload if it fails.""" + try: + self.asset_uploader.upload_asset(file_path) + return None + except Exception as e: + logger.warning(f"Failed to upload file {file_path}: {e}") + return FailedUpload.from_exception(file_path, e) + + with ThreadPoolExecutor( + max_workers=rapidata_config.upload.maxWorkers + ) as executor: + futures = { + executor.submit(upload_single_file, file_path): file_path + for file_path in files + } + + for future in as_completed(futures): + result = future.result() + if result is not None: + failed_uploads.append(result) + + if progress_callback: + progress_callback() + + return failed_uploads diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py new file mode 100644 index 000000000..a8033b836 --- /dev/null +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import time +from typing import Callable, TYPE_CHECKING + +from rapidata.rapidata_client.config import logger, rapidata_config +from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader +from rapidata.api_client.models.batch_upload_status import BatchUploadStatus +from rapidata.api_client.models.batch_upload_url_status import BatchUploadUrlStatus +from rapidata.api_client.models.create_batch_upload_endpoint_input import ( + CreateBatchUploadEndpointInput, +) + +if TYPE_CHECKING: + from rapidata.service.openapi_service import OpenAPIService + + +class BatchAssetUploader: + """ + Handles batch uploading of URL assets using the batch upload API. + + This class submits URLs in batches, polls for completion, and updates + the shared URL cache with successful uploads. + """ + + def __init__(self, openapi_service: OpenAPIService): + self.openapi_service = openapi_service + self.url_cache = AssetUploader._url_cache + + def batch_upload_urls( + self, + urls: list[str], + progress_callback: Callable[[int], None] | None = None, + ) -> list[FailedUpload[str]]: + """ + Upload URLs in batches. Returns list of failed uploads. + Successful uploads are cached automatically. + + Args: + urls: List of URLs to upload. + progress_callback: Optional callback to report progress (called with number of newly completed items). + + Returns: + List of FailedUpload instances for any URLs that failed. 
+ """ + if not urls: + return [] + + # Split into batches + batch_size = rapidata_config.upload.batchSize + batches = [urls[i : i + batch_size] for i in range(0, len(urls), batch_size)] + + logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") + + # Submit all batches immediately (parallel) + batch_ids = [] + for batch_idx, batch in enumerate(batches): + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_post( + create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( + urls=batch + ) + ) + batch_ids.append(result.batch_upload_id) + logger.debug( + f"Submitted batch {batch_idx + 1}/{len(batches)}: {result.batch_upload_id}" + ) + except Exception as e: + logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") + # Fall back to individual uploads for this batch + for url in batch: + failed_upload = FailedUpload( + item=url, + error_type="BatchSubmissionFailed", + error_message=f"Failed to submit batch: {str(e)}", + ) + # Don't return early - try to submit remaining batches + + if not batch_ids: + logger.error("No batches were successfully submitted") + return [ + FailedUpload( + item=url, + error_type="BatchSubmissionFailed", + error_message="Failed to submit any batches", + ) + for url in urls + ] + + # Poll all batches together until complete + logger.debug(f"Polling {len(batch_ids)} batch(es) for completion") + last_completed = 0 + poll_interval = rapidata_config.upload.batchPollInterval + start_time = time.time() + + while True: + try: + status = ( + self.openapi_service.batch_upload_api.asset_batch_upload_status_get( + batch_upload_ids=batch_ids + ) + ) + + # Update progress + if progress_callback: + new_completed = status.completed_count + status.failed_count + if new_completed > last_completed: + progress_callback(new_completed - last_completed) + last_completed = new_completed + + # Check if all complete + if status.status == BatchUploadStatus.COMPLETED: + logger.info( + f"All batches completed: {status.completed_count} succeeded, {status.failed_count} failed" + ) + break + + # Check timeout + elapsed = time.time() - start_time + if elapsed > rapidata_config.upload.batchTimeout: + logger.error( + f"Batch upload timeout after {elapsed:.1f}s (limit: {rapidata_config.upload.batchTimeout}s)" + ) + # Abort batches and return failures + for batch_id in batch_ids: + try: + self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_abort_post( + batch_upload_id=batch_id + ) + except Exception as e: + logger.warning(f"Failed to abort batch {batch_id}: {e}") + return [ + FailedUpload( + item=url, + error_type="BatchUploadTimeout", + error_message=f"Batch upload timed out after {elapsed:.1f}s", + ) + for url in urls + ] + + # Exponential backoff with max interval + time.sleep(poll_interval) + poll_interval = min( + poll_interval * 1.5, rapidata_config.upload.batchPollMaxInterval + ) + + except Exception as e: + logger.error(f"Error polling batch status: {e}") + # Continue polling after error + time.sleep(poll_interval) + + # Fetch results from each batch + logger.debug(f"Fetching results from {len(batch_ids)} batch(es)") + failed_uploads: list[FailedUpload[str]] = [] + successful_count = 0 + + for batch_idx, batch_id in enumerate(batch_ids): + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_get( + batch_upload_id=batch_id + ) + + for item in result.items: + if item.status == BatchUploadUrlStatus.COMPLETED: + # Cache successful upload + cache_key = self._get_url_cache_key(item.url) + 
self.url_cache._storage[cache_key] = item.file_name + successful_count += 1 + logger.debug( + f"Cached successful upload: {item.url} -> {item.file_name}" + ) + else: + # Track failure + failed_uploads.append( + FailedUpload( + item=item.url, + error_type="BatchUploadFailed", + error_message=item.error_message + or "Unknown batch upload error", + ) + ) + logger.warning( + f"URL failed in batch: {item.url} - {item.error_message}" + ) + + except Exception as e: + logger.error(f"Failed to fetch results for batch {batch_id}: {e}") + # Can't determine which URLs failed in this batch + # Add generic error for the entire batch + failed_uploads.append( + FailedUpload( + item=f"batch_{batch_idx}", + error_type="BatchResultFetchFailed", + error_message=f"Failed to fetch batch results: {str(e)}", + ) + ) + + logger.info( + f"Batch upload complete: {successful_count} succeeded, {len(failed_uploads)} failed" + ) + return failed_uploads + + def _get_url_cache_key(self, url: str) -> str: + """Generate cache key for a URL, including environment.""" + env = self.openapi_service.environment + return f"{env}@{url}" diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index fd1b907f3..59911f4f6 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -1,8 +1,12 @@ from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.datapoints._datapoint_uploader import DatapointUploader +from rapidata.rapidata_client.datapoints._asset_upload_orchestrator import ( + AssetUploadOrchestrator, +) from rapidata.rapidata_client.utils.threaded_uploader import ThreadedUploader from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.config import rapidata_config class RapidataDataset: @@ -10,21 +14,33 @@ def __init__(self, dataset_id: str, openapi_service: OpenAPIService): self.id = dataset_id self.openapi_service = openapi_service self.datapoint_uploader = DatapointUploader(openapi_service) + self.asset_orchestrator = AssetUploadOrchestrator(openapi_service) def add_datapoints( self, datapoints: list[Datapoint], ) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: """ - Process uploads in chunks with a ThreadPoolExecutor. 
+ Upload datapoints in two steps: + Step 1/2: Upload all assets (throws exception if fails) + Step 2/2: Create datapoints (using cached assets) Args: datapoints: List of datapoints to upload Returns: tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: Lists of successful uploads and failed uploads with error details + + Raises: + AssetUploadException: If any asset uploads fail in Step 1/2 """ + # STEP 1/2: Upload ALL assets + # This will throw AssetUploadException if any uploads fail + if rapidata_config.upload.enableBatchUpload: + self.asset_orchestrator.upload_all_assets(datapoints) + + # STEP 2/2: Create datapoints (all assets already uploaded) def upload_single_datapoint(datapoint: Datapoint, index: int) -> None: self.datapoint_uploader.upload_datapoint( dataset_id=self.id, @@ -34,7 +50,7 @@ def upload_single_datapoint(datapoint: Datapoint, index: int) -> None: uploader: ThreadedUploader[Datapoint] = ThreadedUploader( upload_fn=upload_single_datapoint, - description="Uploading datapoints", + description="Step 2/2: Creating datapoints", ) successful_uploads, failed_uploads = uploader.upload(datapoints) diff --git a/src/rapidata/rapidata_client/exceptions/__init__.py b/src/rapidata/rapidata_client/exceptions/__init__.py index 082165944..86ca12215 100644 --- a/src/rapidata/rapidata_client/exceptions/__init__.py +++ b/src/rapidata/rapidata_client/exceptions/__init__.py @@ -1,5 +1,6 @@ +from .asset_upload_exception import AssetUploadException from .failed_upload import FailedUpload from .failed_upload_exception import FailedUploadException from .rapidata_error import RapidataError -__all__ = ["FailedUpload", "FailedUploadException", "RapidataError"] +__all__ = ["AssetUploadException", "FailedUpload", "FailedUploadException", "RapidataError"] diff --git a/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py b/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py new file mode 100644 index 000000000..077eb25ee --- /dev/null +++ b/src/rapidata/rapidata_client/exceptions/asset_upload_exception.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload + + +class AssetUploadException(Exception): + """ + Exception raised when asset uploads fail in Step 1/2 of the two-step upload process. + + This exception contains details about which assets failed to upload, + allowing users to decide how to proceed (retry, skip, or abort). + + Attributes: + failed_uploads: List of FailedUpload instances containing the failed assets and error details. + """ + + def __init__(self, failed_uploads: list[FailedUpload[str]]): + self.failed_uploads = failed_uploads + message = ( + f"Failed to upload {len(failed_uploads)} asset(s) in Step 1/2. " + f"See failed_uploads attribute for details." + ) + super().__init__(message) + + def __str__(self) -> str: + error_summary = "\n".join( + f" - {fu.item}: {fu.error_message}" for fu in self.failed_uploads[:5] + ) + if len(self.failed_uploads) > 5: + error_summary += f"\n ... 
and {len(self.failed_uploads) - 5} more" + return f"{super().__str__()}\n{error_summary}" diff --git a/src/rapidata/service/openapi_service.py b/src/rapidata/service/openapi_service.py index 7295eb151..8ad33926e 100644 --- a/src/rapidata/service/openapi_service.py +++ b/src/rapidata/service/openapi_service.py @@ -24,6 +24,7 @@ from rapidata.api_client.api.workflow_api import WorkflowApi from rapidata.api_client.api.participant_api import ParticipantApi from rapidata.api_client.api.audience_api import AudienceApi + from rapidata.api_client.api.batch_upload_api import BatchUploadApi class OpenAPIService: @@ -142,6 +143,12 @@ def asset_api(self) -> AssetApi: return AssetApi(self.api_client) + @property + def batch_upload_api(self) -> BatchUploadApi: + from rapidata.api_client.api.batch_upload_api import BatchUploadApi + + return BatchUploadApi(self.api_client) + @property def dataset_api(self) -> DatasetApi: from rapidata.api_client.api.dataset_api import DatasetApi From e2a6604db2c02c5ef1c369822b648501dcd206d6 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:37:44 +0100 Subject: [PATCH 02/21] improved logic of upload --- .../rapidata_client/config/upload_config.py | 8 +- .../datapoints/_asset_upload_orchestrator.py | 16 +- .../datapoints/_batch_asset_uploader.py | 185 +++++++++++------- .../datapoints/_single_flight_cache.py | 4 + 4 files changed, 132 insertions(+), 81 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index 40c9de7f3..091fa2c61 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -49,10 +49,6 @@ class UploadConfig(BaseModel): default=0.5, description="Polling interval in seconds", ) - batchPollMaxInterval: float = Field( - default=5.0, - description="Maximum polling interval", - ) batchTimeout: float = Field( default=300.0, description="Batch upload timeout in seconds", @@ -85,9 +81,7 @@ def validate_batch_size(cls, v: int) -> int: if v < 10: raise ValueError("batchSize must be at least 10") if v > 500: - logger.warning( - f"batchSize={v} may cause timeouts. Recommend 50-200." - ) + logger.warning(f"batchSize={v} may cause timeouts. Recommend 50-200.") return v def __init__(self, **kwargs): diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 5afd5aa2e..9cbb62db7 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -2,7 +2,7 @@ import re from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from tqdm import tqdm @@ -84,8 +84,12 @@ def upload_all_assets(self, datapoints: list[Datapoint]) -> None: # 4a. Batch upload URLs if uncached_urls: logger.debug(f"Batch uploading {len(uncached_urls)} URL(s)") + + def update_progress(n: int) -> None: + pbar.update(n) + url_failures = self.batch_uploader.batch_upload_urls( - uncached_urls, progress_callback=lambda n: pbar.update(n) + uncached_urls, progress_callback=update_progress ) failed_uploads.extend(url_failures) else: @@ -94,8 +98,12 @@ def upload_all_assets(self, datapoints: list[Datapoint]) -> None: # 4b. 
Parallel upload files if uncached_files: logger.debug(f"Parallel uploading {len(uncached_files)} file(s)") + + def update_file_progress() -> None: + pbar.update(1) + file_failures = self._upload_files_parallel( - uncached_files, progress_callback=lambda: pbar.update(1) + uncached_files, progress_callback=update_file_progress ) failed_uploads.extend(file_failures) else: @@ -150,7 +158,7 @@ def _filter_uncached( def _upload_files_parallel( self, files: list[str], - progress_callback: callable | None = None, + progress_callback: Callable[[], None] | None = None, ) -> list[FailedUpload[str]]: """ Upload files in parallel using ThreadPoolExecutor. diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index a8033b836..badf91df1 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: from rapidata.service.openapi_service import OpenAPIService + from rapidata.api_client.models.get_batch_upload_status_endpoint_output import ( + GetBatchUploadStatusEndpointOutput, + ) class BatchAssetUploader: @@ -47,14 +50,39 @@ def batch_upload_urls( if not urls: return [] - # Split into batches + # Split and submit batches + batches = self._split_into_batches(urls) + batch_ids = self._submit_batches(batches) + + if not batch_ids: + logger.error("No batches were successfully submitted") + return self._create_submission_failures(urls) + + # Poll until complete + self._poll_until_complete(batch_ids, progress_callback) + + # Fetch and process results + return self._fetch_and_process_results(batch_ids) + + def _split_into_batches(self, urls: list[str]) -> list[list[str]]: + """Split URLs into batches of configured size.""" batch_size = rapidata_config.upload.batchSize batches = [urls[i : i + batch_size] for i in range(0, len(urls), batch_size)] - logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") + return batches + + def _submit_batches(self, batches: list[list[str]]) -> list[str]: + """ + Submit all batches to the API. + + Args: + batches: List of URL batches to submit. + + Returns: + List of batch IDs that were successfully submitted. + """ + batch_ids: list[str] = [] - # Submit all batches immediately (parallel) - batch_ids = [] for batch_idx, batch in enumerate(batches): try: result = self.openapi_service.batch_upload_api.asset_batch_upload_post( @@ -68,30 +96,30 @@ def batch_upload_urls( ) except Exception as e: logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") - # Fall back to individual uploads for this batch - for url in batch: - failed_upload = FailedUpload( - item=url, - error_type="BatchSubmissionFailed", - error_message=f"Failed to submit batch: {str(e)}", - ) - # Don't return early - try to submit remaining batches + # Continue trying to submit remaining batches - if not batch_ids: - logger.error("No batches were successfully submitted") - return [ - FailedUpload( - item=url, - error_type="BatchSubmissionFailed", - error_message="Failed to submit any batches", - ) - for url in urls - ] + logger.info(f"Successfully submitted {len(batch_ids)}/{len(batches)} batches") + return batch_ids - # Poll all batches together until complete + def _poll_until_complete( + self, + batch_ids: list[str], + progress_callback: Callable[[int], None] | None, + ) -> None: + """ + Poll batches until all complete. + + Args: + batch_ids: List of batch IDs to poll. 
+ progress_callback: Optional callback to report progress. + """ logger.debug(f"Polling {len(batch_ids)} batch(es) for completion") - last_completed = 0 + + # Scale initial polling interval based on batch count + # More batches = longer expected completion time = less frequent polling poll_interval = rapidata_config.upload.batchPollInterval + + last_completed = 0 start_time = time.time() while True: @@ -103,54 +131,49 @@ def batch_upload_urls( ) # Update progress - if progress_callback: - new_completed = status.completed_count + status.failed_count - if new_completed > last_completed: - progress_callback(new_completed - last_completed) - last_completed = new_completed + self._update_progress(status, last_completed, progress_callback) + last_completed = status.completed_count + status.failed_count - # Check if all complete + # Check completion if status.status == BatchUploadStatus.COMPLETED: + elapsed = time.time() - start_time logger.info( - f"All batches completed: {status.completed_count} succeeded, {status.failed_count} failed" + f"All batches completed in {elapsed:.1f}s: " + f"{status.completed_count} succeeded, {status.failed_count} failed" ) - break + return - # Check timeout - elapsed = time.time() - start_time - if elapsed > rapidata_config.upload.batchTimeout: - logger.error( - f"Batch upload timeout after {elapsed:.1f}s (limit: {rapidata_config.upload.batchTimeout}s)" - ) - # Abort batches and return failures - for batch_id in batch_ids: - try: - self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_abort_post( - batch_upload_id=batch_id - ) - except Exception as e: - logger.warning(f"Failed to abort batch {batch_id}: {e}") - return [ - FailedUpload( - item=url, - error_type="BatchUploadTimeout", - error_message=f"Batch upload timed out after {elapsed:.1f}s", - ) - for url in urls - ] - - # Exponential backoff with max interval + # Wait before next poll (exponential backoff) time.sleep(poll_interval) - poll_interval = min( - poll_interval * 1.5, rapidata_config.upload.batchPollMaxInterval - ) except Exception as e: logger.error(f"Error polling batch status: {e}") - # Continue polling after error time.sleep(poll_interval) - # Fetch results from each batch + def _update_progress( + self, + status: GetBatchUploadStatusEndpointOutput, + last_completed: int, + progress_callback: Callable[[int], None] | None, + ) -> None: + """Update progress callback if provided.""" + if progress_callback: + new_completed = status.completed_count + status.failed_count + if new_completed > last_completed: + progress_callback(new_completed - last_completed) + + def _fetch_and_process_results( + self, batch_ids: list[str] + ) -> list[FailedUpload[str]]: + """ + Fetch results from all batches and process them. + + Args: + batch_ids: List of batch IDs to fetch results from. + + Returns: + List of failed uploads. 
+ """ logger.debug(f"Fetching results from {len(batch_ids)} batch(es)") failed_uploads: list[FailedUpload[str]] = [] successful_count = 0 @@ -161,15 +184,28 @@ def batch_upload_urls( batch_upload_id=batch_id ) + # Process each URL in the batch result for item in result.items: if item.status == BatchUploadUrlStatus.COMPLETED: - # Cache successful upload - cache_key = self._get_url_cache_key(item.url) - self.url_cache._storage[cache_key] = item.file_name - successful_count += 1 - logger.debug( - f"Cached successful upload: {item.url} -> {item.file_name}" - ) + # Cache successful upload using proper API + if item.file_name is not None: + cache_key = self._get_url_cache_key(item.url) + self.url_cache.set(cache_key, item.file_name) + successful_count += 1 + logger.debug( + f"Cached successful upload: {item.url} -> {item.file_name}" + ) + else: + logger.warning( + f"Batch upload completed but file_name is None for URL: {item.url}" + ) + failed_uploads.append( + FailedUpload( + item=item.url, + error_type="BatchUploadFailed", + error_message="Upload completed but file_name is None", + ) + ) else: # Track failure failed_uploads.append( @@ -186,8 +222,6 @@ def batch_upload_urls( except Exception as e: logger.error(f"Failed to fetch results for batch {batch_id}: {e}") - # Can't determine which URLs failed in this batch - # Add generic error for the entire batch failed_uploads.append( FailedUpload( item=f"batch_{batch_idx}", @@ -201,6 +235,17 @@ def batch_upload_urls( ) return failed_uploads + def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str]]: + """Create FailedUpload instances for all URLs when submission fails.""" + return [ + FailedUpload( + item=url, + error_type="BatchSubmissionFailed", + error_message="Failed to submit any batches", + ) + for url in urls + ] + def _get_url_cache_key(self, url: str) -> str: """Generate cache key for a URL, including environment.""" env = self.openapi_service.environment diff --git a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py index c7ae48cd5..8d13b7788 100644 --- a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py +++ b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py @@ -77,6 +77,10 @@ def get_or_fetch( with self._lock: self._in_flight.pop(key, None) + def set(self, key: str, value: str) -> None: + """Set a value in the cache.""" + self._storage[key] = value + def clear(self) -> None: """Clear the cache.""" self._storage.clear() From 02b01c9de9fca8852dd0828b258c25d0bb89da57 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:54:31 +0100 Subject: [PATCH 03/21] added checking of parameters on upload config and removed order config as no longer used --- .../rapidata_client/config/order_config.py | 14 -------------- .../rapidata_client/config/rapidata_config.py | 4 ---- .../rapidata_client/config/upload_config.py | 4 +++- 3 files changed, 3 insertions(+), 19 deletions(-) delete mode 100644 src/rapidata/rapidata_client/config/order_config.py diff --git a/src/rapidata/rapidata_client/config/order_config.py b/src/rapidata/rapidata_client/config/order_config.py deleted file mode 100644 index c9990507d..000000000 --- a/src/rapidata/rapidata_client/config/order_config.py +++ /dev/null @@ -1,14 +0,0 @@ -from pydantic import BaseModel, Field - - -class OrderConfig(BaseModel): - """ - Holds the configuration for the order process. 
- - Attributes: - minOrderDatapointsForValidation (int): The minimum number of datapoints required so that an automatic validationset gets created if no recommended was found. Defaults to 50. - autoValidationSetSize (int): The maximum size of the auto-generated validation set. Defaults to 20. - """ - - autoValidationSetCreation: bool = Field(default=True) - minOrderDatapointsForValidation: int = Field(default=20) diff --git a/src/rapidata/rapidata_client/config/rapidata_config.py b/src/rapidata/rapidata_client/config/rapidata_config.py index 02de167fd..1d2fedfce 100644 --- a/src/rapidata/rapidata_client/config/rapidata_config.py +++ b/src/rapidata/rapidata_client/config/rapidata_config.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, Field from rapidata.rapidata_client.config.logging_config import LoggingConfig -from rapidata.rapidata_client.config.order_config import OrderConfig from rapidata.rapidata_client.config.upload_config import UploadConfig @@ -15,8 +14,6 @@ class RapidataConfig(BaseModel): enableBetaFeatures (bool): Whether to enable beta features. Defaults to False. upload (UploadConfig): The configuration for the upload process. Such as the maximum number of worker threads for processing media paths and the maximum number of retries for failed uploads. - order (OrderConfig): The configuration for the order process. - Such as the minimum number of datapoints required so that an automatic validationset gets created if no recommended was found. logging (LoggingConfig): The configuration for the logging process. Such as the logging level and the logging file. @@ -29,7 +26,6 @@ class RapidataConfig(BaseModel): enableBetaFeatures: bool = False upload: UploadConfig = Field(default_factory=UploadConfig) - order: OrderConfig = Field(default_factory=OrderConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index 091fa2c61..6301190c6 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -1,6 +1,6 @@ from pathlib import Path import shutil -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from rapidata.rapidata_client.config import logger @@ -25,6 +25,8 @@ class UploadConfig(BaseModel): batchTimeout (float): Batch upload timeout in seconds. Defaults to 300.0. 
""" + model_config = ConfigDict(validate_assignment=True) + maxWorkers: int = Field(default=25) maxRetries: int = Field(default=3) cacheUploads: bool = Field(default=True) From 770c3328f53e8e2382cdfa6e448a9c22e6d805bb Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Tue, 27 Jan 2026 16:43:04 +0100 Subject: [PATCH 04/21] changed tqdm to be tqdm.auto --- .../benchmark/participant/_participant.py | 2 +- .../datapoints/_asset_upload_orchestrator.py | 10 +++++----- src/rapidata/rapidata_client/job/rapidata_job.py | 2 +- src/rapidata/rapidata_client/order/rapidata_order.py | 6 ++++-- .../rapidata_client/utils/threaded_uploader.py | 2 +- .../validation/validation_set_manager.py | 2 +- src/rapidata/types/__init__.py | 2 -- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/rapidata/rapidata_client/benchmark/participant/_participant.py b/src/rapidata/rapidata_client/benchmark/participant/_participant.py index 1247043f8..999aff00a 100644 --- a/src/rapidata/rapidata_client/benchmark/participant/_participant.py +++ b/src/rapidata/rapidata_client/benchmark/participant/_participant.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import time -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.rapidata_client.config import logger from rapidata.rapidata_client.config.rapidata_config import rapidata_config diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 9cbb62db7..3f171a66c 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Callable, TYPE_CHECKING -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.rapidata_client.config import logger, rapidata_config from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader @@ -132,16 +132,16 @@ def _extract_unique_assets(self, datapoints: list[Datapoint]) -> set[str]: assets.add(dp.media_context) return assets - def _filter_uncached( - self, assets: list[str], cache - ) -> list[str]: + def _filter_uncached(self, assets: list[str], cache) -> list[str]: """Filter out assets that are already cached.""" uncached = [] for asset in assets: try: # Try to get cache key if re.match(r"^https?://", asset): - cache_key = f"{self.asset_uploader.openapi_service.environment}@{asset}" + cache_key = ( + f"{self.asset_uploader.openapi_service.environment}@{asset}" + ) else: cache_key = self.asset_uploader._get_file_cache_key(asset) diff --git a/src/rapidata/rapidata_client/job/rapidata_job.py b/src/rapidata/rapidata_client/job/rapidata_job.py index fb57247d4..632819124 100644 --- a/src/rapidata/rapidata_client/job/rapidata_job.py +++ b/src/rapidata/rapidata_client/job/rapidata_job.py @@ -7,7 +7,7 @@ from time import sleep from typing import Callable, TypeVar, TYPE_CHECKING from colorama import Fore -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.config import ( diff --git a/src/rapidata/rapidata_client/order/rapidata_order.py b/src/rapidata/rapidata_client/order/rapidata_order.py index 60a00d473..4f272df91 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order.py +++ b/src/rapidata/rapidata_client/order/rapidata_order.py @@ -9,7 +9,7 @@ 
from typing import cast, Callable, TypeVar, TYPE_CHECKING from colorama import Fore from datetime import datetime -from tqdm import tqdm +from tqdm.auto import tqdm # Local/application imports from rapidata.service.openapi_service import OpenAPIService @@ -336,7 +336,9 @@ def get_results(self, preliminary_results: bool = False) -> RapidataResults: with tracer.start_as_current_span("RapidataOrder.get_results"): from rapidata.api_client.models.order_state import OrderState from rapidata.api_client.exceptions import ApiException - from rapidata.rapidata_client.results.rapidata_results import RapidataResults + from rapidata.rapidata_client.results.rapidata_results import ( + RapidataResults, + ) logger.info("Getting results for order '%s'...", self) diff --git a/src/rapidata/rapidata_client/utils/threaded_uploader.py b/src/rapidata/rapidata_client/utils/threaded_uploader.py index b326c6014..a832f9495 100644 --- a/src/rapidata/rapidata_client/utils/threaded_uploader.py +++ b/src/rapidata/rapidata_client/utils/threaded_uploader.py @@ -1,6 +1,6 @@ from typing import TypeVar, Generic, Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from tqdm import tqdm +from tqdm.auto import tqdm import time from rapidata.rapidata_client.config import logger diff --git a/src/rapidata/rapidata_client/validation/validation_set_manager.py b/src/rapidata/rapidata_client/validation/validation_set_manager.py index 4ce03d371..f58bcd5a4 100644 --- a/src/rapidata/rapidata_client/validation/validation_set_manager.py +++ b/src/rapidata/rapidata_client/validation/validation_set_manager.py @@ -30,7 +30,7 @@ rapidata_config, tracer, ) -from tqdm import tqdm +from tqdm.auto import tqdm from rapidata.rapidata_client.validation.rapids.rapids import Rapid if TYPE_CHECKING: diff --git a/src/rapidata/types/__init__.py b/src/rapidata/types/__init__.py index b7079e44f..481e03455 100644 --- a/src/rapidata/types/__init__.py +++ b/src/rapidata/types/__init__.py @@ -71,7 +71,6 @@ # Configuration Types -from rapidata.rapidata_client.config.order_config import OrderConfig from rapidata.rapidata_client.config.upload_config import UploadConfig from rapidata.rapidata_client.config.rapidata_config import RapidataConfig from rapidata.rapidata_client.config.logging_config import LoggingConfig @@ -136,7 +135,6 @@ "TranslationBehaviour", "TranslationBehaviourOptions", # Configuration Types - "OrderConfig", "UploadConfig", "RapidataConfig", "LoggingConfig", From 191b58341cbb8302feb4a33a3ce962232ab259ec Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:13:40 +0100 Subject: [PATCH 05/21] initial test with concurrent datapoint creation --- .../rapidata_client/config/upload_config.py | 4 - .../datapoints/_asset_upload_orchestrator.py | 22 ++- .../datapoints/_batch_asset_uploader.py | 149 ++++++++++-------- .../dataset/_rapidata_dataset.py | 129 ++++++++++++--- 4 files changed, 209 insertions(+), 95 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index 6301190c6..c951526c8 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -39,10 +39,6 @@ class UploadConfig(BaseModel): default=128, frozen=True, ) - enableBatchUpload: bool = Field( - default=True, - description="Enable batch URL uploading (two-step process)", - ) batchSize: int = Field( default=100, description="Number of URLs per batch (10-500)", diff 
--git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 3f171a66c..17e12f586 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -36,13 +36,18 @@ def __init__(self, openapi_service: OpenAPIService): self.asset_uploader = AssetUploader(openapi_service) self.batch_uploader = BatchAssetUploader(openapi_service) - def upload_all_assets(self, datapoints: list[Datapoint]) -> None: + def upload_all_assets( + self, + datapoints: list[Datapoint], + asset_completion_callback: Callable[[list[str]], None] | None = None, + ) -> None: """ Step 1/2: Upload ALL assets from ALL datapoints. Throws AssetUploadException if any uploads fail. Args: datapoints: List of datapoints to extract assets from. + asset_completion_callback: Optional callback to notify when assets complete (called with list of successful assets). Raises: AssetUploadException: If any asset uploads fail. @@ -89,7 +94,9 @@ def update_progress(n: int) -> None: pbar.update(n) url_failures = self.batch_uploader.batch_upload_urls( - uncached_urls, progress_callback=update_progress + uncached_urls, + progress_callback=update_progress, + completion_callback=asset_completion_callback, ) failed_uploads.extend(url_failures) else: @@ -103,7 +110,9 @@ def update_file_progress() -> None: pbar.update(1) file_failures = self._upload_files_parallel( - uncached_files, progress_callback=update_file_progress + uncached_files, + progress_callback=update_file_progress, + completion_callback=asset_completion_callback, ) failed_uploads.extend(file_failures) else: @@ -159,6 +168,7 @@ def _upload_files_parallel( self, files: list[str], progress_callback: Callable[[], None] | None = None, + completion_callback: Callable[[list[str]], None] | None = None, ) -> list[FailedUpload[str]]: """ Upload files in parallel using ThreadPoolExecutor. @@ -166,6 +176,7 @@ def _upload_files_parallel( Args: files: List of file paths to upload. progress_callback: Optional callback to report progress (called once per completed file). + completion_callback: Optional callback to notify when files complete (called with list of successful files). Returns: List of FailedUpload instances for any files that failed. @@ -190,9 +201,14 @@ def upload_single_file(file_path: str) -> FailedUpload[str] | None: } for future in as_completed(futures): + file_path = futures[future] result = future.result() if result is not None: failed_uploads.append(result) + else: + # File uploaded successfully, notify callback + if completion_callback: + completion_callback([file_path]) if progress_callback: progress_callback() diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index badf91df1..11619b6e8 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -35,6 +35,7 @@ def batch_upload_urls( self, urls: list[str], progress_callback: Callable[[int], None] | None = None, + completion_callback: Callable[[list[str]], None] | None = None, ) -> list[FailedUpload[str]]: """ Upload URLs in batches. Returns list of failed uploads. @@ -43,6 +44,7 @@ def batch_upload_urls( Args: urls: List of URLs to upload. progress_callback: Optional callback to report progress (called with number of newly completed items). 
+ completion_callback: Optional callback to notify when URLs complete (called with list of successful URLs). Returns: List of FailedUpload instances for any URLs that failed. @@ -52,17 +54,14 @@ def batch_upload_urls( # Split and submit batches batches = self._split_into_batches(urls) - batch_ids = self._submit_batches(batches) + batch_ids, batch_to_urls = self._submit_batches(batches) if not batch_ids: logger.error("No batches were successfully submitted") return self._create_submission_failures(urls) # Poll until complete - self._poll_until_complete(batch_ids, progress_callback) - - # Fetch and process results - return self._fetch_and_process_results(batch_ids) + return self._poll_until_complete(batch_ids, batch_to_urls, progress_callback, completion_callback) def _split_into_batches(self, urls: list[str]) -> list[list[str]]: """Split URLs into batches of configured size.""" @@ -71,7 +70,7 @@ def _split_into_batches(self, urls: list[str]) -> list[list[str]]: logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") return batches - def _submit_batches(self, batches: list[list[str]]) -> list[str]: + def _submit_batches(self, batches: list[list[str]]) -> tuple[list[str], dict[str, list[str]]]: """ Submit all batches to the API. @@ -79,9 +78,10 @@ def _submit_batches(self, batches: list[list[str]]) -> list[str]: batches: List of URL batches to submit. Returns: - List of batch IDs that were successfully submitted. + Tuple of (batch_ids, batch_to_urls) where batch_to_urls maps batch_id to its URL list. """ batch_ids: list[str] = [] + batch_to_urls: dict[str, list[str]] = {} for batch_idx, batch in enumerate(batches): try: @@ -90,28 +90,37 @@ def _submit_batches(self, batches: list[list[str]]) -> list[str]: urls=batch ) ) - batch_ids.append(result.batch_upload_id) + batch_id = result.batch_upload_id + batch_ids.append(batch_id) + batch_to_urls[batch_id] = batch logger.debug( - f"Submitted batch {batch_idx + 1}/{len(batches)}: {result.batch_upload_id}" + f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" ) except Exception as e: logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") # Continue trying to submit remaining batches logger.info(f"Successfully submitted {len(batch_ids)}/{len(batches)} batches") - return batch_ids + return batch_ids, batch_to_urls def _poll_until_complete( self, batch_ids: list[str], + batch_to_urls: dict[str, list[str]], progress_callback: Callable[[int], None] | None, - ) -> None: + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: """ - Poll batches until all complete. + Poll batches until all complete. Process batches incrementally as they complete. Args: batch_ids: List of batch IDs to poll. + batch_to_urls: Mapping from batch_id to list of URLs in that batch. progress_callback: Optional callback to report progress. + completion_callback: Optional callback to notify when URLs complete. + + Returns: + List of FailedUpload instances for any URLs that failed. 
""" logger.debug(f"Polling {len(batch_ids)} batch(es) for completion") @@ -121,6 +130,8 @@ def _poll_until_complete( last_completed = 0 start_time = time.time() + processed_batches: set[str] = set() + all_failures: list[FailedUpload[str]] = [] while True: try: @@ -130,6 +141,17 @@ def _poll_until_complete( ) ) + # Process newly completed batches + for batch_id in status.completed_batches: + if batch_id not in processed_batches: + successful_urls, failures = self._process_single_batch(batch_id) + processed_batches.add(batch_id) + all_failures.extend(failures) + + # Notify callback with completed URLs + if completion_callback and successful_urls: + completion_callback(successful_urls) + # Update progress self._update_progress(status, last_completed, progress_callback) last_completed = status.completed_count + status.failed_count @@ -141,7 +163,7 @@ def _poll_until_complete( f"All batches completed in {elapsed:.1f}s: " f"{status.completed_count} succeeded, {status.failed_count} failed" ) - return + return all_failures # Wait before next poll (exponential backoff) time.sleep(poll_interval) @@ -162,78 +184,75 @@ def _update_progress( if new_completed > last_completed: progress_callback(new_completed - last_completed) - def _fetch_and_process_results( - self, batch_ids: list[str] - ) -> list[FailedUpload[str]]: + def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUpload[str]]]: """ - Fetch results from all batches and process them. + Fetch and cache results for a single batch. Args: - batch_ids: List of batch IDs to fetch results from. + batch_id: The batch ID to process. Returns: - List of failed uploads. + Tuple of (successful_urls, failed_uploads). """ - logger.debug(f"Fetching results from {len(batch_ids)} batch(es)") + successful_urls: list[str] = [] failed_uploads: list[FailedUpload[str]] = [] - successful_count = 0 - for batch_idx, batch_id in enumerate(batch_ids): - try: - result = self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_get( - batch_upload_id=batch_id - ) + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_get( + batch_upload_id=batch_id + ) - # Process each URL in the batch result - for item in result.items: - if item.status == BatchUploadUrlStatus.COMPLETED: - # Cache successful upload using proper API - if item.file_name is not None: - cache_key = self._get_url_cache_key(item.url) - self.url_cache.set(cache_key, item.file_name) - successful_count += 1 - logger.debug( - f"Cached successful upload: {item.url} -> {item.file_name}" - ) - else: - logger.warning( - f"Batch upload completed but file_name is None for URL: {item.url}" - ) - failed_uploads.append( - FailedUpload( - item=item.url, - error_type="BatchUploadFailed", - error_message="Upload completed but file_name is None", - ) - ) + # Process each URL in the batch result + for item in result.items: + if item.status == BatchUploadUrlStatus.COMPLETED: + # Cache successful upload using proper API + if item.file_name is not None: + cache_key = self._get_url_cache_key(item.url) + self.url_cache.set(cache_key, item.file_name) + successful_urls.append(item.url) + logger.debug( + f"Cached successful upload: {item.url} -> {item.file_name}" + ) else: - # Track failure + logger.warning( + f"Batch upload completed but file_name is None for URL: {item.url}" + ) failed_uploads.append( FailedUpload( item=item.url, error_type="BatchUploadFailed", - error_message=item.error_message - or "Unknown batch upload error", + error_message="Upload completed 
but file_name is None", ) ) - logger.warning( - f"URL failed in batch: {item.url} - {item.error_message}" + else: + # Track failure + failed_uploads.append( + FailedUpload( + item=item.url, + error_type="BatchUploadFailed", + error_message=item.error_message + or "Unknown batch upload error", ) - - except Exception as e: - logger.error(f"Failed to fetch results for batch {batch_id}: {e}") - failed_uploads.append( - FailedUpload( - item=f"batch_{batch_idx}", - error_type="BatchResultFetchFailed", - error_message=f"Failed to fetch batch results: {str(e)}", ) + logger.warning( + f"URL failed in batch: {item.url} - {item.error_message}" + ) + + except Exception as e: + logger.error(f"Failed to fetch results for batch {batch_id}: {e}") + failed_uploads.append( + FailedUpload( + item=f"batch_{batch_id}", + error_type="BatchResultFetchFailed", + error_message=f"Failed to fetch batch results: {str(e)}", ) + ) + + if successful_urls: + logger.debug(f"Batch {batch_id}: {len(successful_urls)} succeeded, {len(failed_uploads)} failed") + + return successful_urls, failed_uploads - logger.info( - f"Batch upload complete: {successful_count} succeeded, {len(failed_uploads)} failed" - ) - return failed_uploads def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str]]: """Create FailedUpload instances for all URLs when submission fails.""" diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index 59911f4f6..9b699ea4d 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -1,12 +1,16 @@ +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor, Future + from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.datapoints._datapoint_uploader import DatapointUploader from rapidata.rapidata_client.datapoints._asset_upload_orchestrator import ( AssetUploadOrchestrator, ) -from rapidata.rapidata_client.utils.threaded_uploader import ThreadedUploader from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload -from rapidata.rapidata_client.config import rapidata_config +from rapidata.rapidata_client.config import rapidata_config, logger class RapidataDataset: @@ -21,9 +25,10 @@ def add_datapoints( datapoints: list[Datapoint], ) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: """ - Upload datapoints in two steps: - Step 1/2: Upload all assets (throws exception if fails) - Step 2/2: Create datapoints (using cached assets) + Upload datapoints with incremental creation: + - Start uploading all assets (URLs in batches + files in parallel) + - As assets complete, check which datapoints are ready and create them + - Continue until all uploads and datapoint creation complete Args: datapoints: List of datapoints to upload @@ -32,29 +37,107 @@ def add_datapoints( tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: Lists of successful uploads and failed uploads with error details Raises: - AssetUploadException: If any asset uploads fail in Step 1/2 + AssetUploadException: If any asset uploads fail and prevent datapoint creation """ + if not datapoints: + return [], [] + + # 1. 
Build mapping: datapoint_index -> required_assets + datapoint_assets: dict[int, set[str]] = {} + for idx, dp in enumerate(datapoints): + assets = set() + if isinstance(dp.asset, list): + assets.update(dp.asset) + else: + assets.add(dp.asset) + if dp.media_context: + assets.add(dp.media_context) + datapoint_assets[idx] = assets + + logger.debug(f"Mapped {len(datapoints)} datapoints to their required assets") + + # 2. Track state (thread-safe) + completed_assets: set[str] = set() + pending_datapoints: set[int] = set(range(len(datapoints))) + creation_futures: list[tuple[int, Future]] = [] + lock = threading.Lock() + + # 3. Create executor for datapoint creation + executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers) + + # 4. Define callback for asset completion + def on_assets_complete(assets: list[str]) -> None: + """Called when a batch of assets completes uploading.""" + with lock: + completed_assets.update(assets) - # STEP 1/2: Upload ALL assets - # This will throw AssetUploadException if any uploads fail - if rapidata_config.upload.enableBatchUpload: - self.asset_orchestrator.upload_all_assets(datapoints) - - # STEP 2/2: Create datapoints (all assets already uploaded) - def upload_single_datapoint(datapoint: Datapoint, index: int) -> None: - self.datapoint_uploader.upload_datapoint( - dataset_id=self.id, - datapoint=datapoint, - index=index, - ) - - uploader: ThreadedUploader[Datapoint] = ThreadedUploader( - upload_fn=upload_single_datapoint, - description="Step 2/2: Creating datapoints", + # Find newly ready datapoints + ready_datapoints = [] + for idx in list(pending_datapoints): + required = datapoint_assets[idx] + if required.issubset(completed_assets): + pending_datapoints.remove(idx) + ready_datapoints.append(idx) + + # Submit ready datapoints for creation (outside lock to avoid blocking) + for idx in ready_datapoints: + future = executor.submit( + self.datapoint_uploader.upload_datapoint, + dataset_id=self.id, + datapoint=datapoints[idx], + index=idx, + ) + with lock: + creation_futures.append((idx, future)) + + if ready_datapoints: + logger.debug( + f"Asset batch completed, {len(ready_datapoints)} datapoints now ready for creation" + ) + + # 5. Start uploads (blocking, but triggers callbacks as assets complete) + logger.info("Starting incremental datapoint creation") + self.asset_orchestrator.upload_all_assets( + datapoints, asset_completion_callback=on_assets_complete ) - successful_uploads, failed_uploads = uploader.upload(datapoints) + # 6. Wait for all datapoint creation to complete + executor.shutdown(wait=True) + logger.debug("All datapoint creation tasks completed") + + # 7. Collect results + successful_uploads: list[Datapoint] = [] + failed_uploads: list[FailedUpload[Datapoint]] = [] + for idx, future in creation_futures: + try: + future.result() # Raises exception if failed + successful_uploads.append(datapoints[idx]) + except Exception as e: + logger.warning(f"Failed to create datapoint {idx}: {e}") + failed_uploads.append( + FailedUpload( + item=datapoints[idx], + error_type="DatapointCreationFailed", + error_message=str(e), + ) + ) + + # 8. 
Handle datapoints whose assets failed to upload + with lock: + for idx in pending_datapoints: + logger.warning(f"Datapoint {idx} assets failed to upload") + failed_uploads.append( + FailedUpload( + item=datapoints[idx], + error_type="AssetUploadFailed", + error_message="One or more required assets failed to upload", + ) + ) + + logger.info( + f"Datapoint creation complete: {len(successful_uploads)} succeeded, {len(failed_uploads)} failed" + ) return successful_uploads, failed_uploads def __str__(self) -> str: From ecf04331da132c8552f02f8cc7c5baf9a3305f30 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:31:21 +0100 Subject: [PATCH 06/21] added double progressbar --- .../datapoints/_asset_upload_orchestrator.py | 2 + .../dataset/_rapidata_dataset.py | 41 ++++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 17e12f586..aa46d8633 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -84,7 +84,9 @@ def upload_all_assets( with tqdm( total=total, desc="Step 1/2: Uploading assets", + position=0, disable=rapidata_config.logging.silent_mode, + leave=True, ) as pbar: # 4a. Batch upload URLs if uncached_urls: diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index 9b699ea4d..092e6130d 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -3,6 +3,8 @@ import threading from concurrent.futures import ThreadPoolExecutor, Future +from tqdm.auto import tqdm + from rapidata.rapidata_client.datapoints._datapoint import Datapoint from rapidata.service.openapi_service import OpenAPIService from rapidata.rapidata_client.datapoints._datapoint_uploader import DatapointUploader @@ -65,7 +67,16 @@ def add_datapoints( # 3. Create executor for datapoint creation executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers) - # 4. Define callback for asset completion + # 4. Create progress bar for datapoint creation (position=1 to show below asset upload bar) + datapoint_pbar = tqdm( + total=len(datapoints), + desc="Step 2/2: Creating datapoints", + position=1, + disable=rapidata_config.logging.silent_mode, + leave=True, + ) + + # 5. 
Define callback for asset completion def on_assets_complete(assets: list[str]) -> None: """Called when a batch of assets completes uploading.""" with lock: @@ -81,12 +92,19 @@ def on_assets_complete(assets: list[str]) -> None: # Submit ready datapoints for creation (outside lock to avoid blocking) for idx in ready_datapoints: - future = executor.submit( - self.datapoint_uploader.upload_datapoint, - dataset_id=self.id, - datapoint=datapoints[idx], - index=idx, - ) + + def upload_and_update(dp_idx): + """Upload datapoint and update progress bar when done.""" + try: + self.datapoint_uploader.upload_datapoint( + dataset_id=self.id, + datapoint=datapoints[dp_idx], + index=dp_idx, + ) + finally: + datapoint_pbar.update(1) + + future = executor.submit(upload_and_update, idx) with lock: creation_futures.append((idx, future)) @@ -95,17 +113,18 @@ def on_assets_complete(assets: list[str]) -> None: f"Asset batch completed, {len(ready_datapoints)} datapoints now ready for creation" ) - # 5. Start uploads (blocking, but triggers callbacks as assets complete) + # 6. Start uploads (blocking, but triggers callbacks as assets complete) logger.info("Starting incremental datapoint creation") self.asset_orchestrator.upload_all_assets( datapoints, asset_completion_callback=on_assets_complete ) - # 6. Wait for all datapoint creation to complete + # 7. Wait for all datapoint creation to complete executor.shutdown(wait=True) + datapoint_pbar.close() logger.debug("All datapoint creation tasks completed") - # 7. Collect results + # 8. Collect results successful_uploads: list[Datapoint] = [] failed_uploads: list[FailedUpload[Datapoint]] = [] @@ -123,7 +142,7 @@ def on_assets_complete(assets: list[str]) -> None: ) ) - # 8. Handle datapoints whose assets failed to upload + # 9. Handle datapoints whose assets failed to upload with lock: for idx in pending_datapoints: logger.warning(f"Datapoint {idx} assets failed to upload") From 2d03906b50ff330fb9036ae632432016c0b146fd Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:39:50 +0100 Subject: [PATCH 07/21] improved speed on checking assets --- .../dataset/_rapidata_dataset.py | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index 092e6130d..2c2633928 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -44,8 +44,11 @@ def add_datapoints( if not datapoints: return [], [] - # 1. Build mapping: datapoint_index -> required_assets - datapoint_assets: dict[int, set[str]] = {} + # 1. 
Build efficient reverse mapping: asset -> datapoint indices that need it + # This allows O(1) lookup instead of checking all pending datapoints + asset_to_datapoints: dict[str, set[int]] = {} + datapoint_pending_count: dict[int, int] = {} # How many assets each datapoint still needs + for idx, dp in enumerate(datapoints): assets = set() if isinstance(dp.asset, list): @@ -54,13 +57,19 @@ def add_datapoints( assets.add(dp.asset) if dp.media_context: assets.add(dp.media_context) - datapoint_assets[idx] = assets + + # Track how many assets this datapoint needs + datapoint_pending_count[idx] = len(assets) + + # Build reverse mapping + for asset in assets: + if asset not in asset_to_datapoints: + asset_to_datapoints[asset] = set() + asset_to_datapoints[asset].add(idx) logger.debug(f"Mapped {len(datapoints)} datapoints to their required assets") # 2. Track state (thread-safe) - completed_assets: set[str] = set() - pending_datapoints: set[int] = set(range(len(datapoints))) creation_futures: list[tuple[int, Future]] = [] lock = threading.Lock() @@ -79,16 +88,25 @@ def add_datapoints( # 5. Define callback for asset completion def on_assets_complete(assets: list[str]) -> None: """Called when a batch of assets completes uploading.""" - with lock: - completed_assets.update(assets) + ready_datapoints = [] - # Find newly ready datapoints - ready_datapoints = [] - for idx in list(pending_datapoints): - required = datapoint_assets[idx] - if required.issubset(completed_assets): - pending_datapoints.remove(idx) - ready_datapoints.append(idx) + with lock: + # For each completed asset, find datapoints that need it + for asset in assets: + if asset in asset_to_datapoints: + # Get all datapoints waiting for this asset + for idx in list(asset_to_datapoints[asset]): + if idx in datapoint_pending_count: + # Decrement the count + datapoint_pending_count[idx] -= 1 + + # If all assets are ready, mark for creation + if datapoint_pending_count[idx] == 0: + ready_datapoints.append(idx) + del datapoint_pending_count[idx] + + # Remove this datapoint from this asset's waiting list + asset_to_datapoints[asset].discard(idx) # Submit ready datapoints for creation (outside lock to avoid blocking) for idx in ready_datapoints: @@ -144,7 +162,7 @@ def upload_and_update(dp_idx): # 9. 
Handle datapoints whose assets failed to upload with lock: - for idx in pending_datapoints: + for idx in datapoint_pending_count: logger.warning(f"Datapoint {idx} assets failed to upload") failed_uploads.append( FailedUpload( From 7f72e01e42db3bd7eb2fb06288285c605b15ec37 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:53:42 +0100 Subject: [PATCH 08/21] moved update of progress bar --- .../datapoints/_batch_asset_uploader.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index 11619b6e8..8d1d2cfa2 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -128,7 +128,6 @@ def _poll_until_complete( # More batches = longer expected completion time = less frequent polling poll_interval = rapidata_config.upload.batchPollInterval - last_completed = 0 start_time = time.time() processed_batches: set[str] = set() all_failures: list[FailedUpload[str]] = [] @@ -148,14 +147,14 @@ def _poll_until_complete( processed_batches.add(batch_id) all_failures.extend(failures) + # Update progress bar immediately based on actual processed URLs + if progress_callback: + progress_callback(len(successful_urls) + len(failures)) + # Notify callback with completed URLs if completion_callback and successful_urls: completion_callback(successful_urls) - # Update progress - self._update_progress(status, last_completed, progress_callback) - last_completed = status.completed_count + status.failed_count - # Check completion if status.status == BatchUploadStatus.COMPLETED: elapsed = time.time() - start_time @@ -172,18 +171,6 @@ def _poll_until_complete( logger.error(f"Error polling batch status: {e}") time.sleep(poll_interval) - def _update_progress( - self, - status: GetBatchUploadStatusEndpointOutput, - last_completed: int, - progress_callback: Callable[[int], None] | None, - ) -> None: - """Update progress callback if provided.""" - if progress_callback: - new_completed = status.completed_count + status.failed_count - if new_completed > last_completed: - progress_callback(new_completed - last_completed) - def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUpload[str]]]: """ Fetch and cache results for a single batch. From b984b8203d5385ddb2a2d54c7093f9b974472104 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:56:29 +0100 Subject: [PATCH 09/21] Revert "moved update of progress bar" This reverts commit 7f72e01e42db3bd7eb2fb06288285c605b15ec37. 
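The restored behavior derives progress from the backend status counts instead of
counting per-batch results. A minimal sketch of the delta update that comes back
with this revert (condensed from _update_progress and its call site below):

    new_completed = status.completed_count + status.failed_count
    if progress_callback and new_completed > last_completed:
        progress_callback(new_completed - last_completed)
    last_completed = new_completed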
--- .../datapoints/_batch_asset_uploader.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index 8d1d2cfa2..11619b6e8 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -128,6 +128,7 @@ def _poll_until_complete( # More batches = longer expected completion time = less frequent polling poll_interval = rapidata_config.upload.batchPollInterval + last_completed = 0 start_time = time.time() processed_batches: set[str] = set() all_failures: list[FailedUpload[str]] = [] @@ -147,14 +148,14 @@ def _poll_until_complete( processed_batches.add(batch_id) all_failures.extend(failures) - # Update progress bar immediately based on actual processed URLs - if progress_callback: - progress_callback(len(successful_urls) + len(failures)) - # Notify callback with completed URLs if completion_callback and successful_urls: completion_callback(successful_urls) + # Update progress + self._update_progress(status, last_completed, progress_callback) + last_completed = status.completed_count + status.failed_count + # Check completion if status.status == BatchUploadStatus.COMPLETED: elapsed = time.time() - start_time @@ -171,6 +172,18 @@ def _poll_until_complete( logger.error(f"Error polling batch status: {e}") time.sleep(poll_interval) + def _update_progress( + self, + status: GetBatchUploadStatusEndpointOutput, + last_completed: int, + progress_callback: Callable[[int], None] | None, + ) -> None: + """Update progress callback if provided.""" + if progress_callback: + new_completed = status.completed_count + status.failed_count + if new_completed > last_completed: + progress_callback(new_completed - last_completed) + def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUpload[str]]]: """ Fetch and cache results for a single batch. From 042d4f5badcda2f02dfd861c6c56fa7f799d096d Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 11:37:39 +0100 Subject: [PATCH 10/21] refactored for better readability --- .../rapidata_client/config/upload_config.py | 1 - .../datapoints/_asset_upload_orchestrator.py | 212 ++++++++--- .../datapoints/_asset_uploader.py | 12 +- .../datapoints/_batch_asset_uploader.py | 60 +-- .../dataset/_rapidata_dataset.py | 348 ++++++++++++++---- .../exceptions/failed_upload.py | 29 +- 6 files changed, 498 insertions(+), 164 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index c951526c8..e6644a0c9 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -21,7 +21,6 @@ class UploadConfig(BaseModel): enableBatchUpload (bool): Enable batch URL uploading (two-step process). Defaults to True. batchSize (int): Number of URLs per batch (10-500). Defaults to 100. batchPollInterval (float): Polling interval in seconds. Defaults to 0.5. - batchPollMaxInterval (float): Maximum polling interval. Defaults to 5.0. batchTimeout (float): Batch upload timeout in seconds. Defaults to 300.0. 
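+
+    Example (illustrative sketch of tuning the batch upload settings; rapidata_config
+    is the shared configuration instance exported by rapidata.rapidata_client.config):
+
+        from rapidata.rapidata_client.config import rapidata_config
+
+        rapidata_config.upload.batchSize = 200          # must be at least 10
+        rapidata_config.upload.batchPollInterval = 1.0  # seconds between status polls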
""" diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index aa46d8633..2226f3c2f 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -11,9 +11,6 @@ from rapidata.rapidata_client.datapoints._batch_asset_uploader import ( BatchAssetUploader, ) -from rapidata.rapidata_client.exceptions.asset_upload_exception import ( - AssetUploadException, -) from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload if TYPE_CHECKING: @@ -21,6 +18,31 @@ from rapidata.service.openapi_service import OpenAPIService +def extract_assets_from_datapoint(datapoint: Datapoint) -> set[str]: + """ + Extract all assets from a single datapoint. + + Args: + datapoint: The datapoint to extract assets from. + + Returns: + Set of asset identifiers (URLs or file paths). + """ + assets: set[str] = set() + + # Main asset(s) + if isinstance(datapoint.asset, list): + assets.update(datapoint.asset) + else: + assets.add(datapoint.asset) + + # Context asset + if datapoint.media_context: + assets.add(datapoint.media_context) + + return assets + + class AssetUploadOrchestrator: """ Orchestrates Step 1/2: Upload ALL assets from ALL datapoints. @@ -29,10 +51,10 @@ class AssetUploadOrchestrator: filters cached assets, and uploads uncached assets using batch upload for URLs and parallel upload for files. - Raises AssetUploadException if any uploads fail. + Returns list of failed uploads for any assets that fail. """ - def __init__(self, openapi_service: OpenAPIService): + def __init__(self, openapi_service: OpenAPIService) -> None: self.asset_uploader = AssetUploader(openapi_service) self.batch_uploader = BatchAssetUploader(openapi_service) @@ -40,47 +62,102 @@ def upload_all_assets( self, datapoints: list[Datapoint], asset_completion_callback: Callable[[list[str]], None] | None = None, - ) -> None: + ) -> list[FailedUpload[str]]: """ Step 1/2: Upload ALL assets from ALL datapoints. - Throws AssetUploadException if any uploads fail. + Returns list of failed uploads for any assets that fail. Args: datapoints: List of datapoints to extract assets from. asset_completion_callback: Optional callback to notify when assets complete (called with list of successful assets). - Raises: - AssetUploadException: If any asset uploads fail. + Returns: + List of FailedUpload instances for any assets that failed. """ - # 1. Extract all unique assets (deduplicate) + # 1. Extract and validate assets all_assets = self._extract_unique_assets(datapoints) logger.info(f"Extracted {len(all_assets)} unique asset(s) from datapoints") if not all_assets: logger.debug("No assets to upload") - return + return [] + + # 2. Separate and filter assets + urls, files = self._separate_urls_and_files(all_assets) + uncached_urls, uncached_files = self._filter_and_log_cached_assets(urls, files) + + if len(uncached_urls) + len(uncached_files) == 0: + logger.debug("All assets cached, nothing to upload") + return [] + + # 3. Perform uploads + failed_uploads = self._perform_uploads( + uncached_urls, uncached_files, asset_completion_callback + ) - # 2. Separate URLs vs files - urls = [a for a in all_assets if re.match(r"^https?://", a)] - files = [a for a in all_assets if not re.match(r"^https?://", a)] + # 4. 
Report results + self._log_upload_results(failed_uploads) + return failed_uploads + + def _separate_urls_and_files(self, assets: set[str]) -> tuple[list[str], list[str]]: + """ + Separate assets into URLs and file paths. + + Args: + assets: Set of asset identifiers. + + Returns: + Tuple of (urls, files). + """ + urls = [a for a in assets if re.match(r"^https?://", a)] + files = [a for a in assets if not re.match(r"^https?://", a)] logger.debug(f"Asset breakdown: {len(urls)} URL(s), {len(files)} file(s)") + return urls, files - # 3. Filter cached (skip already-uploaded assets) + def _filter_and_log_cached_assets( + self, urls: list[str], files: list[str] + ) -> tuple[list[str], list[str]]: + """ + Filter out cached assets and log statistics. + + Args: + urls: List of URL assets. + files: List of file assets. + + Returns: + Tuple of (uncached_urls, uncached_files). + """ uncached_urls = self._filter_uncached(urls, self.asset_uploader._url_cache) uncached_files = self._filter_uncached(files, self.asset_uploader._file_cache) + logger.info( f"Assets to upload: {len(uncached_urls)} URL(s), {len(uncached_files)} file(s) " f"(skipped {len(urls) - len(uncached_urls)} cached URL(s), " f"{len(files) - len(uncached_files)} cached file(s))" ) - total = len(uncached_urls) + len(uncached_files) - if total == 0: - logger.debug("All assets cached, nothing to upload") - return + return uncached_urls, uncached_files + + def _perform_uploads( + self, + uncached_urls: list[str], + uncached_files: list[str], + asset_completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """ + Execute asset uploads with progress tracking. + + Args: + uncached_urls: URLs to upload. + uncached_files: Files to upload. + asset_completion_callback: Callback for completed assets. - # 4. Upload with single progress bar + Returns: + List of failed uploads. + """ failed_uploads: list[FailedUpload[str]] = [] + total = len(uncached_urls) + len(uncached_files) + with tqdm( total=total, desc="Step 1/2: Uploading assets", @@ -88,59 +165,74 @@ def upload_all_assets( disable=rapidata_config.logging.silent_mode, leave=True, ) as pbar: - # 4a. Batch upload URLs + # Upload URLs if uncached_urls: - logger.debug(f"Batch uploading {len(uncached_urls)} URL(s)") - - def update_progress(n: int) -> None: - pbar.update(n) - - url_failures = self.batch_uploader.batch_upload_urls( - uncached_urls, - progress_callback=update_progress, - completion_callback=asset_completion_callback, + url_failures = self._upload_urls_with_progress( + uncached_urls, pbar, asset_completion_callback ) failed_uploads.extend(url_failures) else: logger.debug("No uncached URLs to upload") - # 4b. Parallel upload files + # Upload files if uncached_files: - logger.debug(f"Parallel uploading {len(uncached_files)} file(s)") - - def update_file_progress() -> None: - pbar.update(1) - - file_failures = self._upload_files_parallel( - uncached_files, - progress_callback=update_file_progress, - completion_callback=asset_completion_callback, + file_failures = self._upload_files_with_progress( + uncached_files, pbar, asset_completion_callback ) failed_uploads.extend(file_failures) else: logger.debug("No uncached files to upload") - # 5. 
If any failures, throw exception (before Step 2) - if failed_uploads: - logger.error( - f"Asset upload failed for {len(failed_uploads)} asset(s) in Step 1/2" - ) - raise AssetUploadException(failed_uploads) + return failed_uploads + + def _upload_urls_with_progress( + self, + urls: list[str], + pbar: tqdm, + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """Upload URLs with progress bar updates.""" + logger.debug(f"Batch uploading {len(urls)} URL(s)") + + def update_progress(n: int) -> None: + pbar.update(n) + + return self.batch_uploader.batch_upload_urls( + urls, + progress_callback=update_progress, + completion_callback=completion_callback, + ) + + def _upload_files_with_progress( + self, + files: list[str], + pbar: tqdm, + completion_callback: Callable[[list[str]], None] | None, + ) -> list[FailedUpload[str]]: + """Upload files with progress bar updates.""" + logger.debug(f"Parallel uploading {len(files)} file(s)") + + def update_progress() -> None: + pbar.update(1) - logger.info("Step 1/2: All assets uploaded successfully") + return self._upload_files_parallel( + files, + progress_callback=update_progress, + completion_callback=completion_callback, + ) + + def _log_upload_results(self, failed_uploads: list[FailedUpload[str]]) -> None: + """Log the results of asset uploads.""" + if failed_uploads: + logger.warning(f"Step 1/2: {len(failed_uploads)} asset(s) failed to upload") + else: + logger.info("Step 1/2: All assets uploaded successfully") def _extract_unique_assets(self, datapoints: list[Datapoint]) -> set[str]: """Extract all unique assets from all datapoints.""" assets: set[str] = set() for dp in datapoints: - # Main asset(s) - if isinstance(dp.asset, list): - assets.update(dp.asset) - else: - assets.add(dp.asset) - # Context asset - if dp.media_context: - assets.add(dp.media_context) + assets.update(extract_assets_from_datapoint(dp)) return assets def _filter_uncached(self, assets: list[str], cache) -> list[str]: @@ -148,20 +240,18 @@ def _filter_uncached(self, assets: list[str], cache) -> list[str]: uncached = [] for asset in assets: try: - # Try to get cache key + # Try to get cache key using centralized methods if re.match(r"^https?://", asset): - cache_key = ( - f"{self.asset_uploader.openapi_service.environment}@{asset}" - ) + cache_key = self.asset_uploader.get_url_cache_key(asset) else: - cache_key = self.asset_uploader._get_file_cache_key(asset) + cache_key = self.asset_uploader.get_file_cache_key(asset) # Check if in cache if cache_key not in cache._storage: uncached.append(asset) except Exception as e: # If cache check fails, include in upload list - logger.debug(f"Cache check failed for {asset}: {e}") + logger.warning(f"Cache check failed for {asset}: {e}") uncached.append(asset) return uncached diff --git a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py index 5be916f79..c9e78de56 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py @@ -24,10 +24,10 @@ class AssetUploader: ) _url_cache: SingleFlightCache = SingleFlightCache("URL cache") - def __init__(self, openapi_service: OpenAPIService): + def __init__(self, openapi_service: OpenAPIService) -> None: self.openapi_service = openapi_service - def _get_file_cache_key(self, asset: str) -> str: + def get_file_cache_key(self, asset: str) -> str: """Generate cache key for a file, including environment.""" env = 
self.openapi_service.environment if not os.path.exists(asset): @@ -36,7 +36,7 @@ def _get_file_cache_key(self, asset: str) -> str: stat = os.stat(asset) return f"{env}@{asset}:{stat.st_size}:{stat.st_mtime_ns}" - def _get_url_cache_key(self, url: str) -> str: + def get_url_cache_key(self, url: str) -> str: """Generate cache key for a URL, including environment.""" env = self.openapi_service.environment return f"{env}@{url}" @@ -55,7 +55,7 @@ def upload_url() -> str: return upload_url() return self._url_cache.get_or_fetch( - self._get_url_cache_key(url), + self.get_url_cache_key(url), upload_url, should_cache=rapidata_config.upload.cacheUploads, ) @@ -76,7 +76,7 @@ def upload_file() -> str: return upload_file() return self._file_cache.get_or_fetch( - self._get_file_cache_key(file_path), + self.get_file_cache_key(file_path), upload_file, should_cache=rapidata_config.upload.cacheUploads, ) @@ -91,7 +91,7 @@ def upload_asset(self, asset: str) -> str: return self._upload_file_asset(asset) - def clear_cache(self): + def clear_cache(self) -> None: self._file_cache.clear() self._url_cache.clear() logger.info("Upload cache cleared") diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index 11619b6e8..d7171e847 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -27,8 +27,9 @@ class BatchAssetUploader: the shared URL cache with successful uploads. """ - def __init__(self, openapi_service: OpenAPIService): + def __init__(self, openapi_service: OpenAPIService) -> None: self.openapi_service = openapi_service + self.asset_uploader = AssetUploader(openapi_service) self.url_cache = AssetUploader._url_cache def batch_upload_urls( @@ -61,7 +62,9 @@ def batch_upload_urls( return self._create_submission_failures(urls) # Poll until complete - return self._poll_until_complete(batch_ids, batch_to_urls, progress_callback, completion_callback) + return self._poll_until_complete( + batch_ids, batch_to_urls, progress_callback, completion_callback + ) def _split_into_batches(self, urls: list[str]) -> list[list[str]]: """Split URLs into batches of configured size.""" @@ -70,7 +73,9 @@ def _split_into_batches(self, urls: list[str]) -> list[list[str]]: logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") return batches - def _submit_batches(self, batches: list[list[str]]) -> tuple[list[str], dict[str, list[str]]]: + def _submit_batches( + self, batches: list[list[str]] + ) -> tuple[list[str], dict[str, list[str]]]: """ Submit all batches to the API. 
@@ -124,8 +129,6 @@ def _poll_until_complete( """ logger.debug(f"Polling {len(batch_ids)} batch(es) for completion") - # Scale initial polling interval based on batch count - # More batches = longer expected completion time = less frequent polling poll_interval = rapidata_config.upload.batchPollInterval last_completed = 0 @@ -144,7 +147,9 @@ def _poll_until_complete( # Process newly completed batches for batch_id in status.completed_batches: if batch_id not in processed_batches: - successful_urls, failures = self._process_single_batch(batch_id) + successful_urls, failures = self._process_single_batch( + batch_id, batch_to_urls + ) processed_batches.add(batch_id) all_failures.extend(failures) @@ -165,7 +170,7 @@ def _poll_until_complete( ) return all_failures - # Wait before next poll (exponential backoff) + # Wait before next poll time.sleep(poll_interval) except Exception as e: @@ -184,12 +189,15 @@ def _update_progress( if new_completed > last_completed: progress_callback(new_completed - last_completed) - def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUpload[str]]]: + def _process_single_batch( + self, batch_id: str, batch_to_urls: dict[str, list[str]] + ) -> tuple[list[str], list[FailedUpload[str]]]: """ Fetch and cache results for a single batch. Args: batch_id: The batch ID to process. + batch_to_urls: Mapping from batch_id to list of URLs in that batch. Returns: Tuple of (successful_urls, failed_uploads). @@ -207,7 +215,7 @@ def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUp if item.status == BatchUploadUrlStatus.COMPLETED: # Cache successful upload using proper API if item.file_name is not None: - cache_key = self._get_url_cache_key(item.url) + cache_key = self.asset_uploader.get_url_cache_key(item.url) self.url_cache.set(cache_key, item.file_name) successful_urls.append(item.url) logger.debug( @@ -240,20 +248,33 @@ def _process_single_batch(self, batch_id: str) -> tuple[list[str], list[FailedUp except Exception as e: logger.error(f"Failed to fetch results for batch {batch_id}: {e}") - failed_uploads.append( - FailedUpload( - item=f"batch_{batch_id}", - error_type="BatchResultFetchFailed", - error_message=f"Failed to fetch batch results: {str(e)}", + # Create individual FailedUpload for each URL in the failed batch + if batch_id in batch_to_urls: + for url in batch_to_urls[batch_id]: + failed_uploads.append( + FailedUpload( + item=url, + error_type="BatchResultFetchFailed", + error_message=f"Failed to fetch batch results: {str(e)}", + ) + ) + else: + # Fallback if batch_id not found in mapping + failed_uploads.append( + FailedUpload( + item=f"batch_{batch_id}", + error_type="BatchResultFetchFailed", + error_message=f"Failed to fetch batch results: {str(e)}", + ) ) - ) if successful_urls: - logger.debug(f"Batch {batch_id}: {len(successful_urls)} succeeded, {len(failed_uploads)} failed") + logger.debug( + f"Batch {batch_id}: {len(successful_urls)} succeeded, {len(failed_uploads)} failed" + ) return successful_urls, failed_uploads - def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str]]: """Create FailedUpload instances for all URLs when submission fails.""" return [ @@ -264,8 +285,3 @@ def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str] ) for url in urls ] - - def _get_url_cache_key(self, url: str) -> str: - """Generate cache key for a URL, including environment.""" - env = self.openapi_service.environment - return f"{env}@{url}" diff --git 
a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index 2c2633928..a705b6325 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -1,7 +1,44 @@ +""" +RapidataDataset module for managing datapoint uploads with incremental asset processing. + +Threading Model: +--------------- +This module uses concurrent processing for both asset uploads and datapoint creation: + +1. Asset Upload Phase (Step 1/2): + - URLs are uploaded in batches with polling for completion + - Files are uploaded in parallel using ThreadPoolExecutor + - Completion callbacks are invoked from worker threads + +2. Datapoint Creation Phase (Step 2/2): + - Datapoints are created incrementally as their required assets complete + - Uses ThreadPoolExecutor with max_workers=rapidata_config.upload.maxWorkers + - Callbacks from asset upload trigger datapoint creation submissions + +Thread-Safety: +------------- +- `lock` (threading.Lock): Protects all shared state during incremental processing + - `datapoint_pending_count`: Maps datapoint index to remaining asset count + - `asset_to_datapoints`: Maps asset to set of datapoint indices waiting for it + - `creation_futures`: List of (idx, Future) tuples for datapoint creation tasks + +Lock Acquisition Order: +---------------------- +1. `on_assets_complete` callback acquires lock to update shared state +2. Lock is released before submitting datapoint creation tasks to avoid blocking +3. Lock is re-acquired briefly to append futures to creation_futures list + +The callback-based design ensures: +- Assets can complete incrementally (batch-by-batch, file-by-file) +- Datapoints are created as soon as all their assets are ready +- No deadlocks occur between asset completion and datapoint submission +""" + from __future__ import annotations import threading from concurrent.futures import ThreadPoolExecutor, Future +from typing import Callable from tqdm.auto import tqdm @@ -10,13 +47,14 @@ from rapidata.rapidata_client.datapoints._datapoint_uploader import DatapointUploader from rapidata.rapidata_client.datapoints._asset_upload_orchestrator import ( AssetUploadOrchestrator, + extract_assets_from_datapoint, ) from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload from rapidata.rapidata_client.config import rapidata_config, logger class RapidataDataset: - def __init__(self, dataset_id: str, openapi_service: OpenAPIService): + def __init__(self, dataset_id: str, openapi_service: OpenAPIService) -> None: self.id = dataset_id self.openapi_service = openapi_service self.datapoint_uploader = DatapointUploader(openapi_service) @@ -37,26 +75,56 @@ def add_datapoints( Returns: tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: Lists of successful uploads and failed uploads with error details - - Raises: - AssetUploadException: If any asset uploads fail and prevent datapoint creation """ if not datapoints: return [], [] - # 1. Build efficient reverse mapping: asset -> datapoint indices that need it - # This allows O(1) lookup instead of checking all pending datapoints + # 1. Build asset-to-datapoint mappings + asset_to_datapoints, datapoint_pending_count = ( + self._build_asset_to_datapoint_mapping(datapoints) + ) + + # 2. Set up shared state for incremental creation + creation_futures: list[tuple[int, Future]] = [] + lock = threading.Lock() + executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers) + + # 3. 
Execute uploads and incremental datapoint creation + self._execute_incremental_creation( + datapoints, + asset_to_datapoints, + datapoint_pending_count, + creation_futures, + lock, + executor, + ) + + # 4. Collect and return results + return self._collect_and_return_results( + datapoints, creation_futures, datapoint_pending_count, lock + ) + + def _build_asset_to_datapoint_mapping( + self, datapoints: list[Datapoint] + ) -> tuple[dict[str, set[int]], dict[int, int]]: + """ + Build efficient reverse mapping: asset -> datapoint indices that need it. + This allows O(1) lookup instead of checking all pending datapoints. + + Args: + datapoints: List of datapoints to process. + + Returns: + Tuple of (asset_to_datapoints, datapoint_pending_count): + - asset_to_datapoints: Maps asset to set of datapoint indices waiting for it + - datapoint_pending_count: Maps datapoint index to count of remaining assets needed + """ asset_to_datapoints: dict[str, set[int]] = {} - datapoint_pending_count: dict[int, int] = {} # How many assets each datapoint still needs + datapoint_pending_count: dict[int, int] = {} for idx, dp in enumerate(datapoints): - assets = set() - if isinstance(dp.asset, list): - assets.update(dp.asset) - else: - assets.add(dp.asset) - if dp.media_context: - assets.add(dp.media_context) + # Extract all assets for this datapoint using shared utility + assets = extract_assets_from_datapoint(dp) # Track how many assets this datapoint needs datapoint_pending_count[idx] = len(assets) @@ -68,15 +136,29 @@ def add_datapoints( asset_to_datapoints[asset].add(idx) logger.debug(f"Mapped {len(datapoints)} datapoints to their required assets") + return asset_to_datapoints, datapoint_pending_count - # 2. Track state (thread-safe) - creation_futures: list[tuple[int, Future]] = [] - lock = threading.Lock() - - # 3. Create executor for datapoint creation - executor = ThreadPoolExecutor(max_workers=rapidata_config.upload.maxWorkers) + def _execute_incremental_creation( + self, + datapoints: list[Datapoint], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + ) -> None: + """ + Execute asset uploads and incremental datapoint creation. - # 4. Create progress bar for datapoint creation (position=1 to show below asset upload bar) + Args: + datapoints: List of datapoints being processed. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint. + creation_futures: List to store creation futures. + lock: Lock protecting shared state. + executor: Thread pool executor for datapoint creation. + """ + # Create progress bar for datapoint creation datapoint_pbar = tqdm( total=len(datapoints), desc="Step 2/2: Creating datapoints", @@ -85,67 +167,189 @@ def add_datapoints( leave=True, ) - # 5. 
Define callback for asset completion + try: + # Create callback that submits datapoints for creation + on_assets_complete = self._create_asset_completion_callback( + datapoints, + asset_to_datapoints, + datapoint_pending_count, + creation_futures, + lock, + executor, + datapoint_pbar, + ) + + # Start uploads (blocking, but triggers callbacks as assets complete) + logger.info("Starting incremental datapoint creation") + asset_failures = self.asset_orchestrator.upload_all_assets( + datapoints, asset_completion_callback=on_assets_complete + ) + + if asset_failures: + logger.warning( + f"{len(asset_failures)} asset(s) failed to upload, affected datapoints will be marked as failed" + ) + + # Wait for all datapoint creation to complete + executor.shutdown(wait=True) + logger.debug("All datapoint creation tasks completed") + finally: + # Always close progress bar, even on exception + datapoint_pbar.close() + + def _create_asset_completion_callback( + self, + datapoints: list[Datapoint], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + datapoint_pbar: tqdm, + ) -> Callable[[list[str]], None]: + """ + Create callback function that handles asset completion. + + THREAD-SAFETY: The returned callback is invoked from worker threads during asset upload. + All access to shared state is protected by the lock. + + Args: + datapoints: List of datapoints being processed. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint. + creation_futures: List to store creation futures. + lock: Lock protecting shared state. + executor: Thread pool executor for datapoint creation. + datapoint_pbar: Progress bar for datapoint creation. + + Returns: + Callback function to be invoked when assets complete. + """ + def on_assets_complete(assets: list[str]) -> None: """Called when a batch of assets completes uploading.""" - ready_datapoints = [] + ready_datapoints = self._find_ready_datapoints( + assets, asset_to_datapoints, datapoint_pending_count, lock + ) + + # Submit ready datapoints for creation (outside lock to avoid blocking) + self._submit_datapoints_for_creation( + ready_datapoints, + datapoints, + creation_futures, + lock, + executor, + datapoint_pbar, + ) + + return on_assets_complete + + def _find_ready_datapoints( + self, + assets: list[str], + asset_to_datapoints: dict[str, set[int]], + datapoint_pending_count: dict[int, int], + lock: threading.Lock, + ) -> list[int]: + """ + Find datapoints that are ready for creation after asset completion. + + Args: + assets: List of completed assets. + asset_to_datapoints: Mapping from asset to datapoint indices. + datapoint_pending_count: Pending asset count per datapoint. + lock: Lock protecting shared state. + + Returns: + List of datapoint indices ready for creation. 
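+
+        Example (illustrative): if datapoint 3 still needs {"a.png", "b.png"} and both
+        arrive in the completed assets, its pending count drops to 0, index 3 is
+        returned, and it is removed from both tracking structures.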
+ """ + ready_datapoints = [] + + with lock: + # For each completed asset, find datapoints that need it + for asset in assets: + if asset in asset_to_datapoints: + # Get all datapoints waiting for this asset + for idx in list(asset_to_datapoints[asset]): + if idx in datapoint_pending_count: + # Decrement the count + datapoint_pending_count[idx] -= 1 + + # If all assets are ready, mark for creation + if datapoint_pending_count[idx] == 0: + ready_datapoints.append(idx) + del datapoint_pending_count[idx] + + # Remove this datapoint from this asset's waiting list + asset_to_datapoints[asset].discard(idx) + + return ready_datapoints + def _submit_datapoints_for_creation( + self, + ready_datapoints: list[int], + datapoints: list[Datapoint], + creation_futures: list[tuple[int, Future]], + lock: threading.Lock, + executor: ThreadPoolExecutor, + datapoint_pbar: tqdm, + ) -> None: + """ + Submit ready datapoints for creation. + + Args: + ready_datapoints: Indices of datapoints ready for creation. + datapoints: List of all datapoints. + creation_futures: List to store creation futures. + lock: Lock protecting creation_futures. + executor: Thread pool executor for datapoint creation. + datapoint_pbar: Progress bar for datapoint creation. + """ + for idx in ready_datapoints: + + def upload_and_update(dp_idx): + """Upload datapoint and update progress bar when done.""" + try: + self.datapoint_uploader.upload_datapoint( + dataset_id=self.id, + datapoint=datapoints[dp_idx], + index=dp_idx, + ) + finally: + datapoint_pbar.update(1) + + future = executor.submit(upload_and_update, idx) with lock: - # For each completed asset, find datapoints that need it - for asset in assets: - if asset in asset_to_datapoints: - # Get all datapoints waiting for this asset - for idx in list(asset_to_datapoints[asset]): - if idx in datapoint_pending_count: - # Decrement the count - datapoint_pending_count[idx] -= 1 - - # If all assets are ready, mark for creation - if datapoint_pending_count[idx] == 0: - ready_datapoints.append(idx) - del datapoint_pending_count[idx] - - # Remove this datapoint from this asset's waiting list - asset_to_datapoints[asset].discard(idx) + creation_futures.append((idx, future)) - # Submit ready datapoints for creation (outside lock to avoid blocking) - for idx in ready_datapoints: - - def upload_and_update(dp_idx): - """Upload datapoint and update progress bar when done.""" - try: - self.datapoint_uploader.upload_datapoint( - dataset_id=self.id, - datapoint=datapoints[dp_idx], - index=dp_idx, - ) - finally: - datapoint_pbar.update(1) - - future = executor.submit(upload_and_update, idx) - with lock: - creation_futures.append((idx, future)) - - if ready_datapoints: - logger.debug( - f"Asset batch completed, {len(ready_datapoints)} datapoints now ready for creation" - ) + if ready_datapoints: + logger.debug( + f"Asset batch completed, {len(ready_datapoints)} datapoints now ready for creation" + ) - # 6. Start uploads (blocking, but triggers callbacks as assets complete) - logger.info("Starting incremental datapoint creation") - self.asset_orchestrator.upload_all_assets( - datapoints, asset_completion_callback=on_assets_complete - ) + def _collect_and_return_results( + self, + datapoints: list[Datapoint], + creation_futures: list[tuple[int, Future]], + datapoint_pending_count: dict[int, int], + lock: threading.Lock, + ) -> tuple[list[Datapoint], list[FailedUpload[Datapoint]]]: + """ + Collect results from datapoint creation tasks. - # 7. 
Wait for all datapoint creation to complete - executor.shutdown(wait=True) - datapoint_pbar.close() - logger.debug("All datapoint creation tasks completed") + Args: + datapoints: List of all datapoints. + creation_futures: List of creation futures. + datapoint_pending_count: Datapoints whose assets failed. + lock: Lock protecting datapoint_pending_count. - # 8. Collect results + Returns: + Tuple of (successful_uploads, failed_uploads). + """ successful_uploads: list[Datapoint] = [] failed_uploads: list[FailedUpload[Datapoint]] = [] + # Collect results from creation tasks for idx, future in creation_futures: try: future.result() # Raises exception if failed @@ -160,7 +364,7 @@ def upload_and_update(dp_idx): ) ) - # 9. Handle datapoints whose assets failed to upload + # Handle datapoints whose assets failed to upload with lock: for idx in datapoint_pending_count: logger.warning(f"Datapoint {idx} assets failed to upload") diff --git a/src/rapidata/rapidata_client/exceptions/failed_upload.py b/src/rapidata/rapidata_client/exceptions/failed_upload.py index 1a35e65b5..29b7ff910 100644 --- a/src/rapidata/rapidata_client/exceptions/failed_upload.py +++ b/src/rapidata/rapidata_client/exceptions/failed_upload.py @@ -1,6 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TypeVar, Generic, TYPE_CHECKING +from dataclasses import dataclass, field +from datetime import datetime +from typing import TypeVar, Generic, TYPE_CHECKING, Optional if TYPE_CHECKING: from rapidata.rapidata_client.exceptions.rapidata_error import RapidataError @@ -17,11 +18,15 @@ class FailedUpload(Generic[T]): item: The item that failed to upload. error_message: The error message describing the failure reason. error_type: The type of the exception (e.g., "RapidataError"). + timestamp: Optional timestamp when the failure occurred. + exception: Optional original exception for richer error context. """ item: T error_message: str error_type: str + timestamp: Optional[datetime] = field(default_factory=datetime.now) + exception: Optional[Exception] = None @classmethod def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T]: @@ -43,6 +48,7 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T] item=item, error_message="Unknown error", error_type="Unknown", + exception=None, ) from rapidata.rapidata_client.exceptions.rapidata_error import RapidataError @@ -58,7 +64,26 @@ def from_exception(cls, item: T, exception: Exception | None) -> FailedUpload[T] item=item, error_message=error_message, error_type=error_type, + exception=exception, ) + def format_error_details(self) -> str: + """ + Format error details for logging or display. + + Returns: + Formatted string with all error details including timestamp. 
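+
+        Example (illustrative output for a hypothetical failed URL):
+            Item: https://example.com/broken.png
+            Error Type: BatchUploadFailed
+            Error Message: Unknown batch upload error
+            Timestamp: 2026-01-28T10:31:21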
+ """ + details = [ + f"Item: {self.item}", + f"Error Type: {self.error_type}", + f"Error Message: {self.error_message}", + ] + + if self.timestamp: + details.append(f"Timestamp: {self.timestamp.isoformat()}") + + return "\n".join(details) + def __str__(self) -> str: return f"{self.item}" From 440646462dda54fe5d2e8b37404c05b1b8f0a416 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 14:29:28 +0100 Subject: [PATCH 11/21] moved ability to disable cache to disable cacheToDisk --- .../rapidata_client/config/upload_config.py | 13 ++- .../datapoints/_asset_upload_orchestrator.py | 4 +- .../datapoints/_asset_uploader.py | 85 +++++++++++++------ 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index e6644a0c9..b2900d17e 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -11,13 +11,15 @@ class UploadConfig(BaseModel): Attributes: maxWorkers (int): The maximum number of worker threads for concurrent uploads. Defaults to 25. maxRetries (int): The maximum number of retries for failed uploads. Defaults to 3. - cacheUploads (bool): Enable/disable upload caching. Defaults to True. + cacheToDisk (bool): Enable disk-based caching for file uploads. If False, uses in-memory cache only. Defaults to True. + Note: URL assets are always cached in-memory regardless of this setting. + Caching cannot be disabled entirely as it's required for the two-step upload flow. cacheTimeout (float): Cache operation timeout in seconds. Defaults to 0.1. cacheLocation (Path): Directory for cache storage. Defaults to ~/.cache/rapidata/upload_cache. - This is immutable + This is immutable. Only used for file uploads when cacheToDisk=True. cacheShards (int): Number of cache shards for parallel access. Defaults to 128. Higher values improve concurrency but increase file handles. Must be positive. - This is immutable + This is immutable. Only used for file uploads when cacheToDisk=True. enableBatchUpload (bool): Enable batch URL uploading (two-step process). Defaults to True. batchSize (int): Number of URLs per batch (10-500). Defaults to 100. batchPollInterval (float): Polling interval in seconds. Defaults to 0.5. @@ -28,7 +30,10 @@ class UploadConfig(BaseModel): maxWorkers: int = Field(default=25) maxRetries: int = Field(default=3) - cacheUploads: bool = Field(default=True) + cacheToDisk: bool = Field( + default=True, + description="Enable disk-based caching for file uploads. URLs are always cached in-memory.", + ) cacheTimeout: float = Field(default=0.1) cacheLocation: Path = Field( default=Path.home() / ".cache" / "rapidata" / "upload_cache", diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 2226f3c2f..f4733a94b 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -128,7 +128,9 @@ def _filter_and_log_cached_assets( Tuple of (uncached_urls, uncached_files). 
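+
+        Note: "cached" means the asset's environment-scoped cache key is already
+        present in the underlying cache storage; such assets are skipped here and
+        resolved from the cache when datapoints are created in Step 2/2.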
""" uncached_urls = self._filter_uncached(urls, self.asset_uploader._url_cache) - uncached_files = self._filter_uncached(files, self.asset_uploader._file_cache) + uncached_files = self._filter_uncached( + files, self.asset_uploader._get_file_cache() + ) logger.info( f"Assets to upload: {len(uncached_urls)} URL(s), {len(uncached_files)} file(s) " diff --git a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py index c9e78de56..d24e35d90 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_uploader.py @@ -2,6 +2,7 @@ import re import os +import threading from typing import TYPE_CHECKING from rapidata.service.openapi_service import OpenAPIService @@ -14,15 +15,46 @@ class AssetUploader: - _file_cache: SingleFlightCache = SingleFlightCache( - "File cache", - storage=FanoutCache( - rapidata_config.upload.cacheLocation, - shards=rapidata_config.upload.cacheShards, - timeout=rapidata_config.upload.cacheTimeout, - ), - ) - _url_cache: SingleFlightCache = SingleFlightCache("URL cache") + # Class-level caches shared across all instances + # URL cache: Always in-memory (URLs are lightweight, no benefit to disk caching) + # File cache: Lazily initialized based on cacheToDisk config + _url_cache: SingleFlightCache = SingleFlightCache("URL cache", storage={}) + _file_cache: SingleFlightCache | None = None + _file_cache_lock: threading.Lock = threading.Lock() + + @classmethod + def _get_file_cache(cls) -> SingleFlightCache: + """ + Get or create the file cache based on current config. + + Uses lazy initialization to respect cacheToDisk setting at runtime. + Thread-safe with double-checked locking pattern. + + Returns: + Configured file cache (disk or memory based on cacheToDisk). + """ + if cls._file_cache is not None: + return cls._file_cache + + with cls._file_cache_lock: + # Double-check after acquiring lock + if cls._file_cache is not None: + return cls._file_cache + + # Create cache storage based on current config + if rapidata_config.upload.cacheToDisk: + storage: dict[str, str] | FanoutCache = FanoutCache( + rapidata_config.upload.cacheLocation, + shards=rapidata_config.upload.cacheShards, + timeout=rapidata_config.upload.cacheTimeout, + ) + logger.debug("Initialized file cache with disk storage") + else: + storage = {} + logger.debug("Initialized file cache with in-memory storage") + + cls._file_cache = SingleFlightCache("File cache", storage=storage) + return cls._file_cache def __init__(self, openapi_service: OpenAPIService) -> None: self.openapi_service = openapi_service @@ -42,7 +74,12 @@ def get_url_cache_key(self, url: str) -> str: return f"{env}@{url}" def _upload_url_asset(self, url: str) -> str: - """Upload a URL asset, with optional caching.""" + """ + Upload a URL asset with caching. + + URLs are always cached in-memory (lightweight, no disk I/O overhead). + Caching is required for the two-step upload flow and cannot be disabled. 
+ """ def upload_url() -> str: response = self.openapi_service.asset_api.asset_url_post(url=url) @@ -51,17 +88,15 @@ def upload_url() -> str: ) return response.file_name - if not rapidata_config.upload.cacheUploads: - return upload_url() - - return self._url_cache.get_or_fetch( - self.get_url_cache_key(url), - upload_url, - should_cache=rapidata_config.upload.cacheUploads, - ) + return self._url_cache.get_or_fetch(self.get_url_cache_key(url), upload_url) def _upload_file_asset(self, file_path: str) -> str: - """Upload a local file asset, with optional caching.""" + """ + Upload a local file asset with caching. + + Caching is always enabled as it's required for the two-step upload flow. + Use cacheToDisk config to control whether cache is stored to disk or memory. + """ def upload_file() -> str: response = self.openapi_service.asset_api.asset_file_post(file=file_path) @@ -72,13 +107,8 @@ def upload_file() -> str: ) return response.file_name - if not rapidata_config.upload.cacheUploads: - return upload_file() - - return self._file_cache.get_or_fetch( - self.get_file_cache_key(file_path), - upload_file, - should_cache=rapidata_config.upload.cacheUploads, + return self._get_file_cache().get_or_fetch( + self.get_file_cache_key(file_path), upload_file ) def upload_asset(self, asset: str) -> str: @@ -92,7 +122,8 @@ def upload_asset(self, asset: str) -> str: return self._upload_file_asset(asset) def clear_cache(self) -> None: - self._file_cache.clear() + """Clear both URL and file caches.""" + self._get_file_cache().clear() self._url_cache.clear() logger.info("Upload cache cleared") From 58740defd1ca6fa74d97121986c320f9bce4dc77 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 14:48:03 +0100 Subject: [PATCH 12/21] create failed upload from exception instead of custom where possible --- .../datapoints/_batch_asset_uploader.py | 15 +++------------ .../rapidata_client/dataset/_rapidata_dataset.py | 9 ++------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index d7171e847..f3809ebea 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -249,23 +249,14 @@ def _process_single_batch( except Exception as e: logger.error(f"Failed to fetch results for batch {batch_id}: {e}") # Create individual FailedUpload for each URL in the failed batch + # Use from_exception to extract proper error reason from RapidataError if batch_id in batch_to_urls: for url in batch_to_urls[batch_id]: - failed_uploads.append( - FailedUpload( - item=url, - error_type="BatchResultFetchFailed", - error_message=f"Failed to fetch batch results: {str(e)}", - ) - ) + failed_uploads.append(FailedUpload.from_exception(url, e)) else: # Fallback if batch_id not found in mapping failed_uploads.append( - FailedUpload( - item=f"batch_{batch_id}", - error_type="BatchResultFetchFailed", - error_message=f"Failed to fetch batch results: {str(e)}", - ) + FailedUpload.from_exception(f"batch_{batch_id}", e) ) if successful_urls: diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index a705b6325..ebbc0d944 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -356,13 +356,8 @@ def 
_collect_and_return_results( successful_uploads.append(datapoints[idx]) except Exception as e: logger.warning(f"Failed to create datapoint {idx}: {e}") - failed_uploads.append( - FailedUpload( - item=datapoints[idx], - error_type="DatapointCreationFailed", - error_message=str(e), - ) - ) + # Use from_exception to extract proper error reason from RapidataError + failed_uploads.append(FailedUpload.from_exception(datapoints[idx], e)) # Handle datapoints whose assets failed to upload with lock: From a97cbec772f0393b9579eb3bff8ca360f0424934 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:21:05 +0100 Subject: [PATCH 13/21] removed unused timeout --- src/rapidata/rapidata_client/config/upload_config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index b2900d17e..f26880019 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -23,7 +23,6 @@ class UploadConfig(BaseModel): enableBatchUpload (bool): Enable batch URL uploading (two-step process). Defaults to True. batchSize (int): Number of URLs per batch (10-500). Defaults to 100. batchPollInterval (float): Polling interval in seconds. Defaults to 0.5. - batchTimeout (float): Batch upload timeout in seconds. Defaults to 300.0. """ model_config = ConfigDict(validate_assignment=True) @@ -51,10 +50,6 @@ class UploadConfig(BaseModel): default=0.5, description="Polling interval in seconds", ) - batchTimeout: float = Field( - default=300.0, - description="Batch upload timeout in seconds", - ) @field_validator("maxWorkers") @classmethod From aea89810eb2bd6b993402654d834121c3882c95a Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:42:54 +0100 Subject: [PATCH 14/21] adjusted input in orchestrator and added better doc strings for index usage --- .../datapoints/_asset_upload_orchestrator.py | 19 ++--- .../dataset/_rapidata_dataset.py | 69 +++++++++++-------- 2 files changed, 47 insertions(+), 41 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index f4733a94b..a211d368d 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -60,23 +60,23 @@ def __init__(self, openapi_service: OpenAPIService) -> None: def upload_all_assets( self, - datapoints: list[Datapoint], + assets: set[str] | list[str], asset_completion_callback: Callable[[list[str]], None] | None = None, ) -> list[FailedUpload[str]]: """ - Step 1/2: Upload ALL assets from ALL datapoints. + Step 1/2: Upload ALL assets. Returns list of failed uploads for any assets that fail. Args: - datapoints: List of datapoints to extract assets from. + assets: Set or list of asset identifiers (URLs or file paths) to upload. asset_completion_callback: Optional callback to notify when assets complete (called with list of successful assets). Returns: List of FailedUpload instances for any assets that failed. """ - # 1. Extract and validate assets - all_assets = self._extract_unique_assets(datapoints) - logger.info(f"Extracted {len(all_assets)} unique asset(s) from datapoints") + # 1. 
Validate assets + all_assets = set(assets) if isinstance(assets, list) else assets + logger.info(f"Uploading {len(all_assets)} unique asset(s)") if not all_assets: logger.debug("No assets to upload") @@ -230,13 +230,6 @@ def _log_upload_results(self, failed_uploads: list[FailedUpload[str]]) -> None: else: logger.info("Step 1/2: All assets uploaded successfully") - def _extract_unique_assets(self, datapoints: list[Datapoint]) -> set[str]: - """Extract all unique assets from all datapoints.""" - assets: set[str] = set() - for dp in datapoints: - assets.update(extract_assets_from_datapoint(dp)) - return assets - def _filter_uncached(self, assets: list[str], cache) -> list[str]: """Filter out assets that are already cached.""" uncached = [] diff --git a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py index ebbc0d944..c6cfb9ef7 100644 --- a/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py +++ b/src/rapidata/rapidata_client/dataset/_rapidata_dataset.py @@ -111,14 +111,21 @@ def _build_asset_to_datapoint_mapping( Build efficient reverse mapping: asset -> datapoint indices that need it. This allows O(1) lookup instead of checking all pending datapoints. + Note on using indices: We use integer indices instead of datapoint objects because: + - Indices are hashable (required for dict keys and sets) + - Lightweight and fast to compare + - Provides stable references to datapoints in the list + - The datapoint objects remain in a single location (the datapoints list) + Args: - datapoints: List of datapoints to process. + datapoints: List of datapoints to process. Indices into this list are used as identifiers. Returns: Tuple of (asset_to_datapoints, datapoint_pending_count): - - asset_to_datapoints: Maps asset to set of datapoint indices waiting for it - - datapoint_pending_count: Maps datapoint index to count of remaining assets needed + - asset_to_datapoints: Maps each asset to set of datapoint indices (positions in datapoints list) that need it + - datapoint_pending_count: Maps each datapoint index to count of remaining assets it needs """ + # Using indices instead of datapoints directly for hashability and performance asset_to_datapoints: dict[str, set[int]] = {} datapoint_pending_count: dict[int, int] = {} @@ -179,10 +186,13 @@ def _execute_incremental_creation( datapoint_pbar, ) + # Extract all unique assets from the mapping + all_assets = set(asset_to_datapoints.keys()) + # Start uploads (blocking, but triggers callbacks as assets complete) logger.info("Starting incremental datapoint creation") asset_failures = self.asset_orchestrator.upload_all_assets( - datapoints, asset_completion_callback=on_assets_complete + all_assets, asset_completion_callback=on_assets_complete ) if asset_failures: @@ -228,13 +238,13 @@ def _create_asset_completion_callback( def on_assets_complete(assets: list[str]) -> None: """Called when a batch of assets completes uploading.""" - ready_datapoints = self._find_ready_datapoints( + ready_datapoint_indices = self._find_ready_datapoints( assets, asset_to_datapoints, datapoint_pending_count, lock ) # Submit ready datapoints for creation (outside lock to avoid blocking) self._submit_datapoints_for_creation( - ready_datapoints, + ready_datapoint_indices, datapoints, creation_futures, lock, @@ -254,40 +264,43 @@ def _find_ready_datapoints( """ Find datapoints that are ready for creation after asset completion. 
+ Returns indices into the original datapoints list for datapoints whose + assets are now all complete. + Args: assets: List of completed assets. asset_to_datapoints: Mapping from asset to datapoint indices. - datapoint_pending_count: Pending asset count per datapoint. + datapoint_pending_count: Pending asset count per datapoint index. lock: Lock protecting shared state. Returns: - List of datapoint indices ready for creation. + List of datapoint indices (positions in the datapoints list) ready for creation. """ - ready_datapoints = [] + ready_datapoint_indices = [] with lock: # For each completed asset, find datapoints that need it for asset in assets: if asset in asset_to_datapoints: - # Get all datapoints waiting for this asset - for idx in list(asset_to_datapoints[asset]): - if idx in datapoint_pending_count: - # Decrement the count - datapoint_pending_count[idx] -= 1 + # Get all datapoint indices waiting for this asset + for datapoint_idx in list(asset_to_datapoints[asset]): + if datapoint_idx in datapoint_pending_count: + # Decrement the remaining asset count for this datapoint + datapoint_pending_count[datapoint_idx] -= 1 - # If all assets are ready, mark for creation - if datapoint_pending_count[idx] == 0: - ready_datapoints.append(idx) - del datapoint_pending_count[idx] + # If all assets are now ready, mark this datapoint for creation + if datapoint_pending_count[datapoint_idx] == 0: + ready_datapoint_indices.append(datapoint_idx) + del datapoint_pending_count[datapoint_idx] # Remove this datapoint from this asset's waiting list - asset_to_datapoints[asset].discard(idx) + asset_to_datapoints[asset].discard(datapoint_idx) - return ready_datapoints + return ready_datapoint_indices def _submit_datapoints_for_creation( self, - ready_datapoints: list[int], + ready_datapoint_indices: list[int], datapoints: list[Datapoint], creation_futures: list[tuple[int, Future]], lock: threading.Lock, @@ -298,14 +311,14 @@ def _submit_datapoints_for_creation( Submit ready datapoints for creation. Args: - ready_datapoints: Indices of datapoints ready for creation. - datapoints: List of all datapoints. + ready_datapoint_indices: Indices (positions in datapoints list) of datapoints ready for creation. + datapoints: List of all datapoints (indexed by ready_datapoint_indices). creation_futures: List to store creation futures. lock: Lock protecting creation_futures. executor: Thread pool executor for datapoint creation. datapoint_pbar: Progress bar for datapoint creation. 
""" - for idx in ready_datapoints: + for datapoint_idx in ready_datapoint_indices: def upload_and_update(dp_idx): """Upload datapoint and update progress bar when done.""" @@ -318,13 +331,13 @@ def upload_and_update(dp_idx): finally: datapoint_pbar.update(1) - future = executor.submit(upload_and_update, idx) + future = executor.submit(upload_and_update, datapoint_idx) with lock: - creation_futures.append((idx, future)) + creation_futures.append((datapoint_idx, future)) - if ready_datapoints: + if ready_datapoint_indices: logger.debug( - f"Asset batch completed, {len(ready_datapoints)} datapoints now ready for creation" + f"Asset batch completed, {len(ready_datapoint_indices)} datapoints now ready for creation" ) def _collect_and_return_results( From 67fb6abe23e1798f79549995f73f0500bdadf6c1 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:54:47 +0100 Subject: [PATCH 15/21] upped batchsize to 1000 --- src/rapidata/rapidata_client/config/upload_config.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index f26880019..ff9ee08a4 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -43,8 +43,8 @@ class UploadConfig(BaseModel): frozen=True, ) batchSize: int = Field( - default=100, - description="Number of URLs per batch (10-500)", + default=1000, + description="Number of URLs per batch (100-5000)", ) batchPollInterval: float = Field( default=0.5, @@ -75,10 +75,8 @@ def validate_cache_shards(cls, v: int) -> int: @field_validator("batchSize") @classmethod def validate_batch_size(cls, v: int) -> int: - if v < 10: - raise ValueError("batchSize must be at least 10") - if v > 500: - logger.warning(f"batchSize={v} may cause timeouts. Recommend 50-200.") + if v < 100: + raise ValueError("batchSize must be at least 100") return v def __init__(self, **kwargs): From c96f895e2f0f6cb74fbb972644ce24f4b7feee3d Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:45:39 +0100 Subject: [PATCH 16/21] upped cacheTimout --- src/rapidata/rapidata_client/config/upload_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rapidata/rapidata_client/config/upload_config.py b/src/rapidata/rapidata_client/config/upload_config.py index ff9ee08a4..1b46e55b0 100644 --- a/src/rapidata/rapidata_client/config/upload_config.py +++ b/src/rapidata/rapidata_client/config/upload_config.py @@ -33,7 +33,7 @@ class UploadConfig(BaseModel): default=True, description="Enable disk-based caching for file uploads. 
URLs are always cached in-memory.", ) - cacheTimeout: float = Field(default=0.1) + cacheTimeout: float = Field(default=1) cacheLocation: Path = Field( default=Path.home() / ".cache" / "rapidata" / "upload_cache", frozen=True, From b5210b2432e6cab00a132dc64c29e513f2a5f7a8 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:04:32 +0100 Subject: [PATCH 17/21] simultaneous batch creation and pulling --- .../datapoints/_batch_asset_uploader.py | 135 ++++++++++++------ 1 file changed, 88 insertions(+), 47 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index f3809ebea..03f0447da 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +import threading from typing import Callable, TYPE_CHECKING from rapidata.rapidata_client.config import logger, rapidata_config @@ -42,6 +43,9 @@ def batch_upload_urls( Upload URLs in batches. Returns list of failed uploads. Successful uploads are cached automatically. + Batches are submitted concurrently with polling - polling starts as soon + as the first batch is submitted, allowing progress to be visible immediately. + Args: urls: List of URLs to upload. progress_callback: Optional callback to report progress (called with number of newly completed items). @@ -53,17 +57,67 @@ def batch_upload_urls( if not urls: return [] - # Split and submit batches + # Split into batches batches = self._split_into_batches(urls) - batch_ids, batch_to_urls = self._submit_batches(batches) - if not batch_ids: + # Thread-safe collections for concurrent submission and polling + batch_ids_lock = threading.Lock() + batch_ids: list[str] = [] + batch_to_urls: dict[str, list[str]] = {} + submission_complete = threading.Event() + + # Submit batches in background thread + def submit_batches_background(): + """Submit all batches and signal completion.""" + for batch_idx, batch in enumerate(batches): + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_post( + create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( + urls=batch + ) + ) + batch_id = result.batch_upload_id + + # Add to shared collections (thread-safe) + with batch_ids_lock: + batch_ids.append(batch_id) + batch_to_urls[batch_id] = batch + + logger.debug( + f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" + ) + except Exception as e: + logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") + + # Signal that all batches have been submitted + submission_complete.set() + with batch_ids_lock: + logger.info( + f"Successfully submitted {len(batch_ids)}/{len(batches)} batches" + ) + + # Start background submission + submission_thread = threading.Thread( + target=submit_batches_background, daemon=True + ) + submission_thread.start() + + # Wait for at least one batch to be submitted before starting poll + while len(batch_ids) == 0 and not submission_complete.is_set(): + time.sleep(0.5) + + if len(batch_ids) == 0: logger.error("No batches were successfully submitted") return self._create_submission_failures(urls) - # Poll until complete + # Poll until complete (will handle dynamically growing batch list) return self._poll_until_complete( - batch_ids, batch_to_urls, progress_callback, completion_callback + batch_ids, + batch_to_urls, + batch_ids_lock, + 
submission_complete, + progress_callback, + completion_callback, ) def _split_into_batches(self, urls: list[str]) -> list[list[str]]: @@ -73,62 +127,32 @@ def _split_into_batches(self, urls: list[str]) -> list[list[str]]: logger.info(f"Submitting {len(urls)} URLs in {len(batches)} batch(es)") return batches - def _submit_batches( - self, batches: list[list[str]] - ) -> tuple[list[str], dict[str, list[str]]]: - """ - Submit all batches to the API. - - Args: - batches: List of URL batches to submit. - - Returns: - Tuple of (batch_ids, batch_to_urls) where batch_to_urls maps batch_id to its URL list. - """ - batch_ids: list[str] = [] - batch_to_urls: dict[str, list[str]] = {} - - for batch_idx, batch in enumerate(batches): - try: - result = self.openapi_service.batch_upload_api.asset_batch_upload_post( - create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( - urls=batch - ) - ) - batch_id = result.batch_upload_id - batch_ids.append(batch_id) - batch_to_urls[batch_id] = batch - logger.debug( - f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" - ) - except Exception as e: - logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") - # Continue trying to submit remaining batches - - logger.info(f"Successfully submitted {len(batch_ids)}/{len(batches)} batches") - return batch_ids, batch_to_urls - def _poll_until_complete( self, batch_ids: list[str], batch_to_urls: dict[str, list[str]], + batch_ids_lock: threading.Lock, + submission_complete: threading.Event, progress_callback: Callable[[int], None] | None, completion_callback: Callable[[list[str]], None] | None, ) -> list[FailedUpload[str]]: """ Poll batches until all complete. Process batches incrementally as they complete. + Supports concurrent batch submission - will poll currently submitted batches + and continue until all batches are submitted and completed. + Args: - batch_ids: List of batch IDs to poll. - batch_to_urls: Mapping from batch_id to list of URLs in that batch. + batch_ids: Shared list of batch IDs (grows as batches are submitted). + batch_to_urls: Shared mapping from batch_id to list of URLs in that batch. + batch_ids_lock: Lock protecting batch_ids and batch_to_urls. + submission_complete: Event signaling all batches have been submitted. progress_callback: Optional callback to report progress. completion_callback: Optional callback to notify when URLs complete. Returns: List of FailedUpload instances for any URLs that failed. """ - logger.debug(f"Polling {len(batch_ids)} batch(es) for completion") - poll_interval = rapidata_config.upload.batchPollInterval last_completed = 0 @@ -137,10 +161,23 @@ def _poll_until_complete( all_failures: list[FailedUpload[str]] = [] while True: + # Get current batch IDs (thread-safe) + with batch_ids_lock: + current_batch_ids = batch_ids.copy() + total_batches_submitted = len(current_batch_ids) + + if not current_batch_ids: + # No batches yet, wait a bit + time.sleep(poll_interval) + continue + + logger.debug( + f"Polling {total_batches_submitted} submitted batch(es) for completion" + ) try: status = ( self.openapi_service.batch_upload_api.asset_batch_upload_status_get( - batch_upload_ids=batch_ids + batch_upload_ids=current_batch_ids ) ) @@ -161,8 +198,12 @@ def _poll_until_complete( self._update_progress(status, last_completed, progress_callback) last_completed = status.completed_count + status.failed_count - # Check completion - if status.status == BatchUploadStatus.COMPLETED: + # Check if we're done: + # 1. All batches have been submitted + # 2. 
All submitted batches have been processed + if submission_complete.is_set() and len(processed_batches) == len( + current_batch_ids + ): elapsed = time.time() - start_time logger.info( f"All batches completed in {elapsed:.1f}s: " From 28fe63732d6cb12f6567c710aa79137433948c0a Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:28:46 +0100 Subject: [PATCH 18/21] fixed cached assets not being counted to completed --- .../datapoints/_asset_upload_orchestrator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index a211d368d..4920f9f47 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -86,6 +86,15 @@ def upload_all_assets( urls, files = self._separate_urls_and_files(all_assets) uncached_urls, uncached_files = self._filter_and_log_cached_assets(urls, files) + # Notify callback about cached assets (already complete) + cached_assets = [] + cached_assets.extend([url for url in urls if url not in uncached_urls]) + cached_assets.extend([file for file in files if file not in uncached_files]) + + if cached_assets and asset_completion_callback: + logger.debug(f"Notifying callback of {len(cached_assets)} cached asset(s)") + asset_completion_callback(cached_assets) + if len(uncached_urls) + len(uncached_files) == 0: logger.debug("All assets cached, nothing to upload") return [] From d56c3548b492441259df56d8393dc58862d94c40 Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 16:50:21 +0100 Subject: [PATCH 19/21] removed access to private method --- .../datapoints/_asset_upload_orchestrator.py | 7 +++++-- .../rapidata_client/datapoints/_single_flight_cache.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py index 4920f9f47..8d404f500 100644 --- a/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py +++ b/src/rapidata/rapidata_client/datapoints/_asset_upload_orchestrator.py @@ -12,6 +12,7 @@ BatchAssetUploader, ) from rapidata.rapidata_client.exceptions.failed_upload import FailedUpload +from rapidata.rapidata_client.datapoints._single_flight_cache import SingleFlightCache if TYPE_CHECKING: from rapidata.rapidata_client.datapoints._datapoint import Datapoint @@ -239,7 +240,9 @@ def _log_upload_results(self, failed_uploads: list[FailedUpload[str]]) -> None: else: logger.info("Step 1/2: All assets uploaded successfully") - def _filter_uncached(self, assets: list[str], cache) -> list[str]: + def _filter_uncached( + self, assets: list[str], cache: SingleFlightCache + ) -> list[str]: """Filter out assets that are already cached.""" uncached = [] for asset in assets: @@ -251,7 +254,7 @@ def _filter_uncached(self, assets: list[str], cache) -> list[str]: cache_key = self.asset_uploader.get_file_cache_key(asset) # Check if in cache - if cache_key not in cache._storage: + if cache_key not in cache.get_storage(): uncached.append(asset) except Exception as e: # If cache check fails, include in upload list diff --git a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py index 
8d13b7788..e3f3b1c8e 100644 --- a/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py +++ b/src/rapidata/rapidata_client/datapoints/_single_flight_cache.py @@ -29,6 +29,10 @@ def set_storage(self, storage: dict[str, str] | FanoutCache) -> None: except Exception: pass + def get_storage(self) -> dict[str, str] | FanoutCache: + """Get the cache storage.""" + return self._storage + def get_or_fetch( self, key: str, From 6bc756d65aa6f5e63a63db8a547a6883fbd18b3e Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:22:51 +0100 Subject: [PATCH 20/21] fixed log that would be too big --- src/rapidata/rapidata_client/order/rapidata_order_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rapidata/rapidata_client/order/rapidata_order_manager.py b/src/rapidata/rapidata_client/order/rapidata_order_manager.py index c80026151..f251e2d2e 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order_manager.py +++ b/src/rapidata/rapidata_client/order/rapidata_order_manager.py @@ -102,7 +102,7 @@ def _create_general_order( "Creating order with parameters: name %s, workflow %s, datapoints %s, responses_per_datapoint %s, validation_set_id %s, confidence_threshold %s, filters %s, settings %s, selections %s", name, workflow, - datapoints, + len(datapoints), responses_per_datapoint, validation_set_id, confidence_threshold, From 841b6427c30fd462ba88f1793a0ebfba656f6e3c Mon Sep 17 00:00:00 2001 From: Lino Giger <68745352+LinoGiger@users.noreply.github.com> Date: Thu, 29 Jan 2026 18:00:58 +0100 Subject: [PATCH 21/21] test batch intterupt --- .../datapoints/_batch_asset_uploader.py | 164 +++++++++++++----- 1 file changed, 117 insertions(+), 47 deletions(-) diff --git a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py index 03f0447da..f5ca14c30 100644 --- a/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py +++ b/src/rapidata/rapidata_client/datapoints/_batch_asset_uploader.py @@ -32,6 +32,7 @@ def __init__(self, openapi_service: OpenAPIService) -> None: self.openapi_service = openapi_service self.asset_uploader = AssetUploader(openapi_service) self.url_cache = AssetUploader._url_cache + self._interrupted = False def batch_upload_urls( self, @@ -66,59 +67,75 @@ def batch_upload_urls( batch_to_urls: dict[str, list[str]] = {} submission_complete = threading.Event() - # Submit batches in background thread - def submit_batches_background(): - """Submit all batches and signal completion.""" - for batch_idx, batch in enumerate(batches): - try: - result = self.openapi_service.batch_upload_api.asset_batch_upload_post( - create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( - urls=batch + try: + # Submit batches in background thread + def submit_batches_background(): + """Submit all batches and signal completion.""" + for batch_idx, batch in enumerate(batches): + # Check if interrupted before submitting next batch + if self._interrupted: + logger.debug("Batch submission stopped due to interruption") + break + + try: + result = self.openapi_service.batch_upload_api.asset_batch_upload_post( + create_batch_upload_endpoint_input=CreateBatchUploadEndpointInput( + urls=batch + ) ) - ) - batch_id = result.batch_upload_id + batch_id = result.batch_upload_id - # Add to shared collections (thread-safe) - with batch_ids_lock: - batch_ids.append(batch_id) - batch_to_urls[batch_id] = batch + # Add to shared collections 
(thread-safe) + with batch_ids_lock: + batch_ids.append(batch_id) + batch_to_urls[batch_id] = batch - logger.debug( - f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" + logger.debug( + f"Submitted batch {batch_idx + 1}/{len(batches)}: {batch_id}" + ) + except Exception as e: + logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") + + # Signal that all batches have been submitted + submission_complete.set() + with batch_ids_lock: + logger.info( + f"Successfully submitted {len(batch_ids)}/{len(batches)} batches" ) - except Exception as e: - logger.error(f"Failed to submit batch {batch_idx + 1}: {e}") - # Signal that all batches have been submitted - submission_complete.set() - with batch_ids_lock: - logger.info( - f"Successfully submitted {len(batch_ids)}/{len(batches)} batches" - ) + # Start background submission + submission_thread = threading.Thread( + target=submit_batches_background, daemon=True + ) + submission_thread.start() + + # Wait for at least one batch to be submitted before starting poll + while len(batch_ids) == 0 and not submission_complete.is_set(): + time.sleep(0.5) + + if len(batch_ids) == 0: + logger.error("No batches were successfully submitted") + return self._create_submission_failures(urls) + + # Poll until complete (will handle dynamically growing batch list) + return self._poll_until_complete( + batch_ids, + batch_to_urls, + batch_ids_lock, + submission_complete, + progress_callback, + completion_callback, + ) - # Start background submission - submission_thread = threading.Thread( - target=submit_batches_background, daemon=True - ) - submission_thread.start() - - # Wait for at least one batch to be submitted before starting poll - while len(batch_ids) == 0 and not submission_complete.is_set(): - time.sleep(0.5) - - if len(batch_ids) == 0: - logger.error("No batches were successfully submitted") - return self._create_submission_failures(urls) - - # Poll until complete (will handle dynamically growing batch list) - return self._poll_until_complete( - batch_ids, - batch_to_urls, - batch_ids_lock, - submission_complete, - progress_callback, - completion_callback, - ) + except KeyboardInterrupt: + logger.warning("Batch upload interrupted by user (Ctrl+C)") + self._interrupted = True + raise # Re-raise to propagate interruption + + finally: + # Cleanup: abort batches if interrupted + if self._interrupted: + self._abort_batches(batch_ids, batch_ids_lock) def _split_into_batches(self, urls: list[str]) -> list[list[str]]: """Split URLs into batches of configured size.""" @@ -161,6 +178,11 @@ def _poll_until_complete( all_failures: list[FailedUpload[str]] = [] while True: + # Check for interruption at start of each iteration + if self._interrupted: + logger.debug("Polling stopped due to interruption") + break + # Get current batch IDs (thread-safe) with batch_ids_lock: current_batch_ids = batch_ids.copy() @@ -218,6 +240,9 @@ def _poll_until_complete( logger.error(f"Error polling batch status: {e}") time.sleep(poll_interval) + # Return failures collected so far (reached via break on interruption) + return all_failures + def _update_progress( self, status: GetBatchUploadStatusEndpointOutput, @@ -317,3 +342,48 @@ def _create_submission_failures(self, urls: list[str]) -> list[FailedUpload[str] ) for url in urls ] + + def _abort_batches( + self, + batch_ids: list[str], + batch_ids_lock: threading.Lock, + ) -> None: + """ + Abort all submitted batches by calling the abort endpoint. + + This method is called during cleanup when the upload process is interrupted. 
+        It attempts to abort all batches that were successfully submitted.
+
+        Args:
+            batch_ids: Shared list of batch IDs (thread-safe access required).
+            batch_ids_lock: Lock protecting batch_ids list.
+        """
+        # Get snapshot of current batch IDs
+        with batch_ids_lock:
+            batches_to_abort = batch_ids.copy()
+
+        if not batches_to_abort:
+            logger.info("No batches to abort")
+            return
+
+        logger.info(
+            f"Aborting {len(batches_to_abort)} batch(es) due to interruption..."
+        )
+
+        abort_successes = 0
+        abort_failures = 0
+
+        for batch_id in batches_to_abort:
+            try:
+                self.openapi_service.batch_upload_api.asset_batch_upload_batch_upload_id_abort_post(
+                    batch_upload_id=batch_id
+                )
+                abort_successes += 1
+                logger.debug(f"Successfully aborted batch: {batch_id}")
+            except Exception as e:
+                abort_failures += 1
+                logger.warning(f"Failed to abort batch {batch_id}: {e}")
+
+        logger.info(
+            f"Batch abort completed: {abort_successes} succeeded, {abort_failures} failed"
+        )
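
Note on the upload flow introduced above: the simultaneous submit-and-poll behaviour (PATCH 17) and the interrupt handling with batch abort (PATCH 21) combine into a producer/consumer loop. Below is a minimal, self-contained sketch of that control flow. It is an illustration only, not the SDK's implementation: submit_batch, poll_batches, abort_batches, and upload_in_batches are hypothetical stand-ins for the real endpoints and classes.

import threading
import time


def submit_batch(batch: list[str]) -> str:
    """Hypothetical stand-in for the batch-submission endpoint."""
    time.sleep(0.1)  # pretend network latency
    return f"batch-{abs(hash(tuple(batch))) % 10_000}"


def poll_batches(batch_ids: list[str]) -> set[str]:
    """Hypothetical stand-in for the status endpoint; returns ids of completed batches."""
    time.sleep(0.1)
    return set(batch_ids)  # pretend every submitted batch finishes immediately


def abort_batches(batch_ids: list[str]) -> None:
    """Hypothetical stand-in for the abort endpoint."""
    for batch_id in batch_ids:
        print(f"aborting {batch_id}")


def upload_in_batches(
    urls: list[str], batch_size: int = 100, poll_interval: float = 0.5
) -> None:
    batches = [urls[i : i + batch_size] for i in range(0, len(urls), batch_size)]
    batch_ids: list[str] = []
    lock = threading.Lock()
    submission_complete = threading.Event()
    interrupted = threading.Event()

    def submit_all() -> None:
        # Producer: submit each batch and publish its id under the lock,
        # stopping early if the main thread was interrupted.
        for batch in batches:
            if interrupted.is_set():
                break
            batch_id = submit_batch(batch)
            with lock:
                batch_ids.append(batch_id)
        submission_complete.set()

    threading.Thread(target=submit_all, daemon=True).start()

    done: set[str] = set()
    try:
        while True:
            with lock:
                current = list(batch_ids)
            if current:
                done |= poll_batches(current)
            # Finished only once every batch has been submitted *and* reported complete.
            if submission_complete.is_set() and len(done) == len(current):
                break
            time.sleep(poll_interval)
    except KeyboardInterrupt:
        # Mirror the cleanup idea from PATCH 21: stop submitting, abort what exists.
        interrupted.set()
        with lock:
            abort_batches(list(batch_ids))
        raise


if __name__ == "__main__":
    upload_in_batches([f"https://example.com/img_{i}.png" for i in range(250)])

The completion condition mirrors the one used in _poll_until_complete: the loop only exits once submission has finished and every submitted batch has been accounted for, which is what allows polling to begin before all batches have been created.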