diff --git a/src/openmind_hub/plugins/openmind/om_api.py b/src/openmind_hub/plugins/openmind/om_api.py index e9b5651e18b5fb088848efe060bf282ee8a4746a..2c2e8045e92fbe835785ce4c4cad22ad28d9a82b 100644 --- a/src/openmind_hub/plugins/openmind/om_api.py +++ b/src/openmind_hub/plugins/openmind/om_api.py @@ -15,6 +15,7 @@ import base64 import gc import os +import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, asdict from datetime import datetime @@ -1384,11 +1385,21 @@ class OmApi: params = {"ref": revision} headers = self.build_om_headers(token=token) path = f"{self.endpoint}/api/v1/dataset/{repo_id}" - r = get_session().get(path, headers=headers, params=params, timeout=timeout or DEFAULT_REQUEST_TIMEOUT) - om_raise_for_status(r) - data = r.json().get("data") - data["sha"] = self.get_repo_last_commit(repo_id=repo_id, token=token, revision=revision).oid - return DatasetInfo(**data) + + MAX_RETRIES = 3 + retry_count = 0 + while retry_count < MAX_RETRIES: + try: + r = get_session().get(path, headers=headers, params=params, timeout=timeout or DEFAULT_REQUEST_TIMEOUT) + om_raise_for_status(r) + data = r.json().get("data") + data["sha"] = self.get_repo_last_commit(repo_id=repo_id, token=token, revision=revision).oid + return DatasetInfo(**data) + except Exception as e: + retry_count += 1 + if retry_count == MAX_RETRIES: + raise e + time.sleep(2**retry_count) @validate_om_hub_args def space_info( diff --git a/src/openmind_hub/plugins/openmind/utils/_validators.py b/src/openmind_hub/plugins/openmind/utils/_validators.py index ce4fb18e8f46d7f1b9b083d7a2e2f2f5a3e0d1d8..75ffeee7fd6660c066a6d5667d12c8d8182c479d 100644 --- a/src/openmind_hub/plugins/openmind/utils/_validators.py +++ b/src/openmind_hub/plugins/openmind/utils/_validators.py @@ -101,7 +101,7 @@ def validate_gitcode_repo_id(repo_id: str) -> str: if os.getenv("OPENMIND_PLATFORM") != "gitcode": raise OMValidationError("Repo id must be in the form 'repo_name' or 'owner/repo_name'") - repo_id = repo_id[: repo_id.find("/")] + "%2F" + repo_id[repo_id.find("/") + 1:] + repo_id = repo_id[: repo_id.find("/")] + "%2F" + repo_id[repo_id.find("/") + 1 :] owner, repo = repo_id.split("/") if "/" in repo_id else (None, repo_id)