diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index bf5bfda2..00000000 --- a/.coveragerc +++ /dev/null @@ -1,31 +0,0 @@ -# .coveragerc to control coverage.py -[run] -branch = True -source = aleph.sdk -# omit = bad_file.py - -[paths] -source = - src/ - */site-packages/ - -[report] -# Regexes for lines to exclude from consideration -exclude_lines = - # Have to re-enable the standard pragma - pragma: no cover - - # Don't complain about missing debug-only code: - def __repr__ - if self\.debug - - # Don't complain if tests don't hit defensive assertion code: - raise AssertionError - raise NotImplementedError - - # Don't complain if non-runnable code isn't run: - if 0: - if __name__ == .__main__.: - - # Don't complain about ineffective code: - pass diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..dbc6b016 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +# To modify src/aleph/sdk/conf.py, create a .env file and add: +# ALEPH_= +# To modify active & rpc fields in CHAINS, follow this example: +# ALEPH_CHAINS_SEPOLIA_ACTIVE=True +# ALEPH_CHAINS_SEPOLIA_RPC=https://... 
\ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..8ceeaf03 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: +- package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index 241d738c..96d828ad 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -11,12 +11,13 @@ on: jobs: build: strategy: + fail-fast: false matrix: - os: [macos-11, macos-12, ubuntu-20.04, ubuntu-22.04] + os: [macos-14, ubuntu-22.04, ubuntu-24.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Workaround github issue https://github.com/actions/runner-images/issues/7192 if: startsWith(matrix.os, 'ubuntu-') @@ -26,13 +27,13 @@ jobs: if: startsWith(matrix.os, 'macos') uses: actions/setup-python@v2 with: - python-version: 3.11 + python-version: 3.12 - name: Cache dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip - key: ${{ runner.os }}-build-wheels-${{ hashFiles('setup.cfg', 'setup.py') }} + key: ${{ runner.os }}-build-wheels-${{ hashFiles('pyproject.toml') }} restore-keys: | ${{ runner.os }}-build-wheels- @@ -40,8 +41,7 @@ jobs: if: startsWith(matrix.os, 'macos-') run: | brew update - brew tap cuber/homebrew-libsecp256k1 - brew install libsecp256k1 + brew install secp256k1 - name: Install required system packages only for Ubuntu Linux if: startsWith(matrix.os, 'ubuntu-') @@ -50,22 +50,23 @@ jobs: sudo apt-get -y upgrade sudo apt-get install -y libsecp256k1-dev - - name: Install required Python packages + - name: Install Hatch run: | - python3 -m pip install --upgrade build - python3 -m pip install --user --upgrade twine + python3 -m venv /tmp/venv + /tmp/venv/bin/python3 -m pip install --upgrade hatch - name: Build source and wheel packages run: | - python3 -m build 
+ /tmp/venv/bin/python3 -m hatch build - name: Install the Python wheel run: | - python3 -m pip install dist/aleph_sdk_python-*.whl + /tmp/venv/bin/python3 -m pip install dist/aleph_sdk_python-*.whl + + - name: Install/upgrade `setuptools` + run: /tmp/venv/bin/python3 -m pip install --upgrade setuptools - name: Import and use the package - # macos tests fail this step because they use Python 3.11, which is not yet supported by our dependencies - if: startsWith(matrix.os, 'ubuntu-') run: | - python3 -c "import aleph.sdk" - python3 -c "from aleph.sdk.chains.ethereum import get_fallback_account; get_fallback_account()" + /tmp/venv/bin/python3 -c "import aleph.sdk" + /tmp/venv/bin/python3 -c "from aleph.sdk.chains.ethereum import get_fallback_account; get_fallback_account()" diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 6576a34e..16ec4e91 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -8,44 +8,30 @@ on: jobs: code-quality: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Workaround github issue https://github.com/actions/runner-images/issues/7192 run: sudo echo RESET grub-efi/install_devices | sudo debconf-communicate grub-pc + - name: Install system dependencies + run: | + sudo apt-get install -y python3-pip libsecp256k1-dev + - name: Cache dependencies - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip - key: ${{ runner.os }}-code-quality-${{ hashFiles('setup.cfg', 'setup.py') }} + key: ${{ runner.os }}-code-quality-${{ hashFiles('pyproject.toml') }} restore-keys: | ${{ runner.os }}-code-quality- - - name: Install required system packages only for Ubuntu Linux - run: | - sudo apt-get update - sudo apt-get -y upgrade - sudo apt-get install -y libsecp256k1-dev - - - name: Install required Python packages - run: | - python3 -m pip install -e .[testing,ethereum] - - - name: Test with 
Black + - name: Install python dependencies run: | - black --check ./src/ ./tests/ ./examples/ + python3 -m venv /tmp/venv + /tmp/venv/bin/pip install hatch - - name: Test with isort - run: | - isort --check-only ./src/ ./tests/ ./examples/ - - - name: Test with MyPy - run: | - mypy --config-file ./mypy.ini ./src/ ./tests/ ./examples/ - - - name: Test with flake8 - run: | - flake8 ./src/ ./tests/ ./examples/ + - name: Run Hatch lint + run: /tmp/venv/bin/hatch run linting:all diff --git a/.github/workflows/pr-rating.yml b/.github/workflows/pr-rating.yml index 8f42647d..1378b687 100644 --- a/.github/workflows/pr-rating.yml +++ b/.github/workflows/pr-rating.yml @@ -13,7 +13,7 @@ jobs: if: github.event.pull_request.draft == false steps: - name: PR Difficulty Rating - uses: rate-my-pr/rate@v1 + uses: rate-my-pr/rate@v2 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} LLAMA_URL: ${{ secrets.LLAMA_URL }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..5bbaeb4a --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,48 @@ +--- +name: Publish to PyPI + +on: + push: + tags: + - "*" + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install build dependencies + run: pip install hatch hatch-vcs + + - name: Build package + run: hatch build + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + publish: + needs: build + runs-on: ubuntu-latest + environment: pypi + permissions: + id-token: write + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/pytest-docker.yml b/.github/workflows/pytest-docker.yml deleted file mode 100644 index 
d6e0759d..00000000 --- a/.github/workflows/pytest-docker.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Test using Pytest in Docker - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - build: - strategy: - matrix: - image: [ "python-3.9", "python-3.10", "python-3.11", "ubuntu-20.04", "ubuntu-22.04" ] - runs-on: ubuntu-22.04 - - steps: - - uses: actions/checkout@v3 - - # Use GitHub's Docker registry to cache intermediate layers - - run: echo ${{ secrets.GITHUB_TOKEN }} | docker login docker.pkg.github.com -u $GITHUB_ACTOR --password-stdin - - run: docker pull docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python-build-cache || true - - - name: Build the Docker image - run: | - git fetch --prune --unshallow --tags - docker build . -t aleph-sdk-python:${GITHUB_REF##*/} -f docker/${{matrix.image}}.dockerfile --cache-from=docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python-build-cache - - - name: Push the image on GitHub's repository - run: docker tag aleph-sdk-python:${GITHUB_REF##*/} docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python:${GITHUB_REF##*/} && docker push docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python:${GITHUB_REF##*/} || true - - - name: Cache the image on GitHub's repository - run: docker tag aleph-sdk-python:${GITHUB_REF##*/} docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python-build-cache && docker push docker.pkg.github.com/$GITHUB_REPOSITORY/aleph-sdk-python-build-cache || true - - - name: Pytest in the Docker image - run: | - docker run --entrypoint /opt/venv/bin/pytest aleph-sdk-python:${GITHUB_REF##*/} /opt/aleph-sdk-python/ diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 00000000..c2e08466 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,61 @@ +name: Test/Coverage with Python + +on: + push: + pull_request: + branches: + - main + schedule: + - cron: '4 0 * * *' + +jobs: + tests: + strategy: + fail-fast: false + matrix: + 
python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-22.04, ubuntu-24.04, macos-14, macos-15] + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: "apt-get install" + run: | + sudo apt-get update + sudo apt-get install -y python3-pip libsodium-dev libgmp-dev + if: runner.os == 'Linux' + + - run: | + brew install libsodium + echo "DYLD_LIBRARY_PATH=$(brew --prefix libsodium)/lib" >> $GITHUB_ENV + if: runner.os == 'macOS' + + # Workaround to avoid building pyobjc-core on macOS14 + Python 3.9. Support for Python 3.9 will be dropped + # once we support a more recent version of Python on functions. + - name: Avoid building pyobjc-core on macOS+Py3.9 + if: runner.os == 'macOS' && matrix.python-version == '3.9' + run: | + echo "pyobjc-core<12" > /tmp/constraints.txt + echo "PIP_CONSTRAINT=/tmp/constraints.txt" >> $GITHUB_ENV + + - name: "Install Hatch" + run: | + python3 -m venv /tmp/venv + /tmp/venv/bin/python -m pip install --upgrade pip hatch coverage + + - name: "Run Tests" + run: | + /tmp/venv/bin/pip freeze + /tmp/venv/bin/hatch run testing:pip freeze + /tmp/venv/bin/hatch run testing:test + + - run: /tmp/venv/bin/hatch run testing:cov + + - uses: codecov/codecov-action@v4.0.1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: aleph-im/aleph-sdk-python diff --git a/.gitignore b/.gitignore index c4734889..5ab655ca 100644 --- a/.gitignore +++ b/.gitignore @@ -47,6 +47,11 @@ MANIFEST # Per-project virtualenvs .venv*/ +venv/* **/device.key +# environment variables +.config.json +.env.local + .gitsigners diff --git a/AUTHORS.rst b/AUTHORS.rst index a7c7459a..65119268 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -5,3 +5,4 @@ Contributors * Henry Taieb * Hugo Herter * Moshe Malawach +* Mike Hukiewitz \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst deleted file mode 100644 index 30607221..00000000 --- a/CHANGELOG.rst +++ 
/dev/null @@ -1,8 +0,0 @@ -========= -Changelog -========= - -Version 0.1 -=========== - -- Converted from minialeph \ No newline at end of file diff --git a/README.md b/README.md index e2e74cf9..00ec940d 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,9 @@ Python SDK for the Aleph.im network, next generation network of decentralized bi Development follows the [Aleph Whitepaper](https://github.com/aleph-im/aleph-whitepaper). ## Documentation -Documentation (albeit still vastly incomplete as it is a work in progress) can be found at [http://aleph-sdk-python.readthedocs.io/](http://aleph-sdk-python.readthedocs.io/) or built from this repo with: +The latest documentation, albeit incomplete, is available at [https://docs.aleph.im/libraries/python-sdk/](https://docs.aleph.im/libraries/python-sdk/). -```shell -$ python setup.py docs -``` +For the full documentation, please refer to the docstrings in the source code. ## Requirements ### Linux @@ -23,26 +21,70 @@ Using some chains may also require installing `libgmp3-dev`. ### macOs This project does not support Python 3.12 on macOS. Please use Python 3.11 instead. ```shell -$ brew tap cuber/homebrew-libsecp256k1 -$ brew install libsecp256k1 +$ brew install secp256k1 ``` ## Installation Using pip and [PyPI](https://pypi.org/project/aleph-sdk-python/): ```shell -$ pip install aleph-sdk-python[ethereum,solana,tezos] +$ pip install aleph-sdk-python +``` + +### Additional dependencies +Some functionalities require additional dependencies. 
They can be installed like this: + +```shell +$ pip install aleph-sdk-python[solana, dns] ``` +The following extra dependencies are available: +- `solana` for Solana accounts and signatures +- `cosmos` for Substrate/Cosmos accounts and signatures +- `nuls2` for NULS2 accounts and signatures +- `polkadot` for Polkadot accounts and signatures +- `ledger` for Ledger hardware wallet support, see [Usage with LedgerHQ hardware](#usage-with-ledgerhq-hardware) +- `mqtt` for MQTT-related functionalities, see [examples/mqtt.py](examples/mqtt.py) +- `docs` for building the documentation, see [Documentation](#documentation) +- `dns` for DNS-related functionalities +- `all` installs all extra dependencies + + ## Installation for development -To install from source and still be able to modify the source code: +Setup a virtual environment using [hatch](https://hatch.pypa.io/): +```shell +$ hatch shell +``` + +Then install the SDK from source with all extra dependencies: ```shell -$ pip install -e .[testing] +$ pip install -e .[all] ``` -or + +### Running tests & Hatch scripts +You can use the test env defined for hatch to run the tests: + +```shell +$ hatch run testing:run +``` + +See `hatch env show` for more information about all the environments and their scripts. + +### Generating the documentation [DEPRECATED] +The documentation is built using [Sphinx](https://www.sphinx-doc.org/). 
+ +To build the documentation, install the SDK with the `docs` extra dependencies: + +```shell +$ pip install -e .[docs] +``` + +Then build the documentation: + ```shell -$ python setup.py develop +$ cd docs +$ make html ``` ## Usage with LedgerHQ hardware diff --git a/docker/python-3.10.dockerfile b/docker/python-3.10.dockerfile deleted file mode 100644 index 3af183ca..00000000 --- a/docker/python-3.10.dockerfile +++ /dev/null @@ -1,39 +0,0 @@ -FROM python:3.10-bullseye -MAINTAINER The aleph.im project - -RUN apt-get update && apt-get -y upgrade && apt-get install -y \ - libsecp256k1-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN useradd -s /bin/bash --create-home user -RUN mkdir /opt/venv -RUN mkdir /opt/aleph-sdk-python/ -RUN chown user:user /opt/venv -RUN chown user:user /opt/aleph-sdk-python - -USER user -RUN python3 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -ENV PATH="/opt/venv/bin:$PATH" - -RUN pip install --upgrade pip wheel twine - -# Preinstall dependencies for faster steps -RUN pip install --upgrade secp256k1 coincurve aiohttp eciespy python-magic typer -RUN pip install --upgrade 'aleph-message~=0.4.0' eth_account pynacl base58 -RUN pip install --upgrade pytest pytest-cov pytest-asyncio mypy types-setuptools pytest-asyncio fastapi httpx requests - -WORKDIR /opt/aleph-sdk-python/ -COPY . . 
-USER root -RUN chown -R user:user /opt/aleph-sdk-python - -RUN git config --global --add safe.directory /opt/aleph-sdk-python -RUN pip install -e .[testing,ethereum,solana,tezos,ledger] - -RUN mkdir /data -RUN chown user:user /data -ENV ALEPH_PRIVATE_KEY_FILE=/data/secret.key - -WORKDIR /home/user -USER user diff --git a/docker/python-3.11.dockerfile b/docker/python-3.11.dockerfile deleted file mode 100644 index 644195a7..00000000 --- a/docker/python-3.11.dockerfile +++ /dev/null @@ -1,39 +0,0 @@ -FROM python:3.11-bullseye -MAINTAINER The aleph.im project - -RUN apt-get update && apt-get -y upgrade && apt-get install -y \ - libsecp256k1-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN useradd -s /bin/bash --create-home user -RUN mkdir /opt/venv -RUN mkdir /opt/aleph-sdk-python/ -RUN chown user:user /opt/venv -RUN chown user:user /opt/aleph-sdk-python - -USER user -RUN python3 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -ENV PATH="/opt/venv/bin:$PATH" - -RUN pip install --upgrade pip wheel twine - -# Preinstall dependencies for faster steps -RUN pip install --upgrade secp256k1 coincurve aiohttp eciespy python-magic typer -RUN pip install --upgrade 'aleph-message~=0.4.0' pynacl base58 -RUN pip install --upgrade pytest pytest-cov pytest-asyncio mypy types-setuptools pytest-asyncio fastapi httpx requests - -WORKDIR /opt/aleph-sdk-python/ -COPY . . 
-USER root -RUN chown -R user:user /opt/aleph-sdk-python - -RUN git config --global --add safe.directory /opt/aleph-sdk-python -RUN pip install -e .[testing,ethereum,solana,tezos,ledger] - -RUN mkdir /data -RUN chown user:user /data -ENV ALEPH_PRIVATE_KEY_FILE=/data/secret.key - -WORKDIR /home/user -USER user diff --git a/docker/python-3.9.dockerfile b/docker/python-3.9.dockerfile deleted file mode 100644 index ff6d3c41..00000000 --- a/docker/python-3.9.dockerfile +++ /dev/null @@ -1,39 +0,0 @@ -FROM python:3.9-bullseye -MAINTAINER The aleph.im project - -RUN apt-get update && apt-get -y upgrade && apt-get install -y \ - libsecp256k1-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN useradd -s /bin/bash --create-home user -RUN mkdir /opt/venv -RUN mkdir /opt/aleph-sdk-python/ -RUN chown user:user /opt/venv -RUN chown user:user /opt/aleph-sdk-python - -USER user -RUN python3 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -ENV PATH="/opt/venv/bin:$PATH" - -RUN pip install --upgrade pip wheel twine - -# Preinstall dependencies for faster steps -RUN pip install --upgrade secp256k1 coincurve aiohttp eciespy python-magic typer -RUN pip install --upgrade 'aleph-message~=0.4.0' eth_account pynacl base58 -RUN pip install --upgrade pytest pytest-cov pytest-asyncio mypy types-setuptools pytest-asyncio fastapi httpx requests - -WORKDIR /opt/aleph-sdk-python/ -COPY . . 
-USER root -RUN chown -R user:user /opt/aleph-sdk-python - -RUN git config --global --add safe.directory /opt/aleph-sdk-python -RUN pip install -e .[testing,ethereum,solana,tezos,ledger] - -RUN mkdir /data -RUN chown user:user /data -ENV ALEPH_PRIVATE_KEY_FILE=/data/secret.key - -WORKDIR /home/user -USER user diff --git a/docker/ubuntu-20.04.dockerfile b/docker/ubuntu-20.04.dockerfile deleted file mode 100644 index cb0d7c7e..00000000 --- a/docker/ubuntu-20.04.dockerfile +++ /dev/null @@ -1,44 +0,0 @@ -FROM ubuntu:20.04 -MAINTAINER The aleph.im project - -RUN apt-get update && apt-get -y upgrade && apt-get install -y \ - libsecp256k1-dev \ - python3-dev \ - python3-venv \ - git \ - build-essential \ - libgmp3-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN useradd -s /bin/bash --create-home user -RUN mkdir /opt/venv -RUN mkdir /opt/aleph-sdk-python/ -RUN chown user:user /opt/venv -RUN chown user:user /opt/aleph-sdk-python - -USER user -RUN python3 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -ENV PATH="/opt/venv/bin:$PATH" - -RUN pip install --upgrade pip wheel twine - -# Preinstall dependencies for faster steps -RUN pip install --upgrade secp256k1 coincurve aiohttp eciespy python-magic typer -RUN pip install --upgrade 'aleph-message~=0.4.0' eth_account pynacl base58 -RUN pip install --upgrade pytest pytest-cov pytest-asyncio mypy types-setuptools pytest-asyncio fastapi httpx requests - -WORKDIR /opt/aleph-sdk-python/ -COPY . . 
-USER root -RUN chown -R user:user /opt/aleph-sdk-python - -RUN git config --global --add safe.directory /opt/aleph-sdk-python -RUN pip install -e .[testing,ethereum,solana,tezos,ledger] - -RUN mkdir /data -RUN chown user:user /data -ENV ALEPH_PRIVATE_KEY_FILE=/data/secret.key - -WORKDIR /home/user -USER user diff --git a/docker/ubuntu-22.04.dockerfile b/docker/ubuntu-22.04.dockerfile deleted file mode 100644 index 8e44e482..00000000 --- a/docker/ubuntu-22.04.dockerfile +++ /dev/null @@ -1,44 +0,0 @@ -FROM ubuntu:22.04 -MAINTAINER The aleph.im project - -RUN apt-get update && apt-get -y upgrade && apt-get install -y \ - libsecp256k1-dev \ - python3-dev \ - python3-venv \ - git \ - build-essential \ - libgmp3-dev \ - && rm -rf /var/lib/apt/lists/* - -RUN useradd -s /bin/bash --create-home user -RUN mkdir /opt/venv -RUN mkdir /opt/aleph-sdk-python/ -RUN chown user:user /opt/venv -RUN chown user:user /opt/aleph-sdk-python - -USER user -RUN python3 -m venv /opt/venv -ENV PATH="/opt/venv/bin:$PATH" -ENV PATH="/opt/venv/bin:$PATH" - -RUN pip install --upgrade pip wheel twine - -# Preinstall dependencies for faster steps -RUN pip install --upgrade secp256k1 coincurve aiohttp eciespy python-magic typer -RUN pip install --upgrade 'aleph-message~=0.4.0' eth_account pynacl base58 -RUN pip install --upgrade pytest pytest-cov pytest-asyncio mypy types-setuptools pytest-asyncio fastapi httpx requests - -WORKDIR /opt/aleph-sdk-python/ -COPY . . 
-USER root -RUN chown -R user:user /opt/aleph-sdk-python - -RUN git config --global --add safe.directory /opt/aleph-sdk-python -RUN pip install -e .[testing,ethereum,solana,tezos,ledger] - -RUN mkdir /data -RUN chown user:user /data -ENV ALEPH_PRIVATE_KEY_FILE=/data/secret.key - -WORKDIR /home/user -USER user diff --git a/docker/with-ipfs.dockerfile b/docker/with-ipfs.dockerfile index e9625f18..507ee0ea 100644 --- a/docker/with-ipfs.dockerfile +++ b/docker/with-ipfs.dockerfile @@ -29,7 +29,7 @@ RUN mkdir /opt/aleph-sdk-python/ WORKDIR /opt/aleph-sdk-python/ COPY . . -RUN pip install -e .[testing,ethereum] +RUN pip install -e .[testing] # - User 'aleph' to run the code itself diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 87a87653..00000000 --- a/mypy.ini +++ /dev/null @@ -1,67 +0,0 @@ -# Global options: - -[mypy] -python_version = 3.8 - -mypy_path = src - -exclude = conftest.py - - -show_column_numbers = True - -# Suppressing errors -# Shows errors related to strict None checking, if the global strict_optional flag is enabled -strict_optional = True -no_implicit_optional = True - -# Import discovery -# Suppresses error messages about imports that cannot be resolved -ignore_missing_imports = True -# Forces import to reference the original source file -no_implicit_reexport = True -# show error messages from unrelated files -follow_imports = silent -follow_imports_for_stubs = False - - -# Disallow dynamic typing -# Disallows usage of types that come from unfollowed imports -disallow_any_unimported = False -# Disallows all expressions in the module that have type Any -disallow_any_expr = False -# Disallows functions that have Any in their signature after decorator transformation. -disallow_any_decorated = False -# Disallows explicit Any in type positions such as type annotations and generic type parameters. -disallow_any_explicit = False -# Disallows usage of generic types that do not specify explicit type parameters. 
-disallow_any_generics = False -# Disallows subclassing a value of type Any. -disallow_subclassing_any = False - -# Untyped definitions and calls -# Disallows calling functions without type annotations from functions with type annotations. -disallow_untyped_calls = False -# Disallows defining functions without type annotations or with incomplete type annotations -disallow_untyped_defs = False -# Disallows defining functions with incomplete type annotations. -check_untyped_defs = True -# Type-checks the interior of functions without type annotations. -disallow_incomplete_defs = False -# Reports an error whenever a function with type annotations is decorated with a decorator without annotations. -disallow_untyped_decorators = False - -# Prohibit comparisons of non-overlapping types (ex: 42 == "no") -strict_equality = True - -# Configuring warnings -# Warns about unneeded # type: ignore comments. -warn_unused_ignores = True -# Shows errors for missing return statements on some execution paths. -warn_no_return = True -# Shows a warning when returning a value with type Any from a function declared with a non- Any return type. -warn_return_any = False - -# Miscellaneous strictness flags -# Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition. 
-allow_redefinition = True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..ad051b58 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,255 @@ +[build-system] +build-backend = "hatchling.build" + +requires = [ "hatch-vcs", "hatchling" ] + +[project] +name = "aleph-sdk-python" +description = "Lightweight Python Client library for the Aleph.im network" +readme = "README.md" +license = { file = "LICENSE.txt" } +authors = [ + { name = "Aleph.im Team", email = "hello@aleph.im" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Framework :: aiohttp", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Libraries", +] +dynamic = [ "version" ] +dependencies = [ + "aiohttp>=3.8.3", + "aioresponses>=0.7.6", + "aleph-message>=1.1", + "aleph-superfluid>=0.3", + "base58==2.1.1", # Needed now as default with _load_account changement + "coincurve; python_version>='3.9'", + "coincurve>=19; python_version>='3.9'", + "eth-abi>=5.0.1; python_version>='3.9'", + "eth-typing>=5.0.1", + "jwcrypto==1.5.6", + "ledgerblue>=0.1.48", + "ledgereth>=0.10", + "pydantic>=2,<3", + "pydantic-settings>=2", + "pynacl==1.5", # Needed now as default with _load_account changement + "python-magic", + "typing-extensions", + "web3>=7.10", +] + +optional-dependencies.all = [ + "aleph-sdk-python[cosmos,dns,docs,ledger,mqtt,nuls2,substrate,solana,tezos,encryption]", +] +optional-dependencies.cosmos = [ + "cosmospy", +] +optional-dependencies.dns = [ + "aiodns", + "tldextract", +] +optional-dependencies.docs = [ + "sphinxcontrib-plantuml", +] 
+optional-dependencies.encryption = [ + "eciespy; python_version<'3.11'", + "eciespy>=0.3.13; python_version>='3.11'", +] +optional-dependencies.ledger = [ + "ledgereth==0.10", +] +optional-dependencies.mqtt = [ + "aiomqtt<=0.1.3", + "certifi", + "click", +] +optional-dependencies.nuls2 = [ + "aleph-nuls2", +] +optional-dependencies.solana = [ + "base58", + "pynacl", +] +optional-dependencies.substrate = [ + "py-sr25519-bindings", + "substrate-interface<1.8", +] +optional-dependencies.tezos = [ + "pytezos-crypto==3.13.4.1", +] +urls.Documentation = "https://aleph.im/" +urls.Homepage = "https://github.com/aleph-im/aleph-sdk-python" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.build.targets.wheel] +packages = [ + "src/aleph", + "pyproject.toml", + "README.md", + "LICENSE.txt", +] + +[tool.hatch.build.targets.sdist] +include = [ + "src/aleph", + "pyproject.toml", + "README.md", + "LICENSE.txt", +] + +[[tool.hatch.envs.all.matrix]] +python = [ "3.9", "3.10", "3.11" ] + +[tool.hatch.envs.testing] +python = "3.13" +features = [ + "cosmos", + "dns", + "ledger", + "nuls2", + "substrate", + "solana", + "tezos", + "encryption", +] +dependencies = [ + "pytest==8.0.1", + "pytest-cov==4.1.0", + "pytest-mock==3.12.0", + "pytest-asyncio==0.23.5", + "pytest-aiohttp==1.0.5", + "aioresponses==0.7.6", + "fastapi", + "httpx", + "secp256k1", +] +[tool.hatch.envs.testing.scripts] +test = "pytest {args:} ./src/ ./tests/ ./examples/" +test-cov = "pytest --cov {args:} ./src/ ./tests/ ./examples/" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[tool.hatch.envs.linting] +detached = true +dependencies = [ + "black==24.2.0", + "mypy==1.9.0", + "mypy-extensions==1.0.0", + "ruff==0.4.8", + "isort==5.13.2", + "pyproject-fmt==2.2.1", +] +[tool.hatch.envs.linting.scripts] +typing = "mypy --config-file=pyproject.toml {args:} ./src/ ./tests/ ./examples/" +style = [ + "ruff 
check {args:.} ./src/ ./tests/ ./examples/", + "black --check --diff {args:} ./src/ ./tests/ ./examples/", + "isort --check-only --profile black {args:} ./src/ ./tests/ ./examples/", + "pyproject-fmt --check pyproject.toml", +] +fmt = [ + "black {args:} ./src/ ./tests/ ./examples/", + "ruff check --fix {args:.} ./src/ ./tests/ ./examples/", + "isort --profile black {args:} ./src/ ./tests/ ./examples/", + "pyproject-fmt pyproject.toml", + "style", +] +all = [ + "style", + "typing", +] + +[tool.isort] +profile = "black" + +[tool.pytest.ini_options] +minversion = "6.0" +pythonpath = [ "src" ] +addopts = "-vv -m \"not ledger_hardware\"" +norecursedirs = [ "*.egg", "dist", "build", ".tox", ".venv", "*/site-packages/*" ] +testpaths = [ "tests/unit" ] +markers = { ledger_hardware = "marks tests as requiring ledger hardware" } + +[tool.coverage.run] +branch = true +parallel = true +source = [ + "src/", +] +omit = [ + "*/site-packages/*", +] + +[tool.coverage.paths] +source = [ + "src/", +] +omit = [ + "*/site-packages/*", +] + +[tool.coverage.report] +show_missing = true +skip_empty = true +exclude_lines = [ + # Have to re-enable the standard pragma + "pragma: no cover", + + # Don't complain about missing debug-only code: + "def __repr__", + "if self\\.debug", + + # Don't complain if tests don't hit defensive assertion code: + "raise AssertionError", + "raise NotImplementedError", + + # Don't complain if non-runnable code isn't run: + "if 0:", + "if __name__ == .__main__.:", + + # Don't complain about ineffective code: + "pass", +] + +[tool.mypy] +python_version = 3.9 +mypy_path = "src" +exclude = [ + "conftest.py", +] +show_column_numbers = true +check_untyped_defs = true + +# Import discovery +# Install types for third-party library stubs (e.g. from typeshed repository) +install_types = true +non_interactive = true +# Suppresses error messages about imports that cannot be resolved (no py.typed file, no stub file, etc). 
+ignore_missing_imports = true +# Don't follow imports +follow_imports = "silent" + +# Miscellaneous strictness flags +# Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition. +allow_redefinition = true diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index de505203..00000000 --- a/setup.cfg +++ /dev/null @@ -1,178 +0,0 @@ -# This file is used to configure your project. -# Read more about the various options under: -# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files - -[metadata] -name = aleph-sdk-python -description = Lightweight Python Client library for the Aleph.im network -author = Aleph.im Team -author_email = hello@aleph.im -license = mit -long_description = file: README.md -long_description_content_type = text/markdown; charset=UTF-8 -url = https://github.com/aleph-im/aleph-sdk-python -project_urls = - Documentation = https://aleph.im/ -# Change if running only on Windows, Mac or Linux (comma-separated) -platforms = any -# Add here all kinds of additional classifiers as defined under -# https://pypi.python.org/pypi?%3Aaction=list_classifiers -classifiers = - Development Status :: 4 - Beta - Programming Language :: Python :: 3 - -[options] -zip_safe = False -packages = find: -include_package_data = True -package_dir = - =src -# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! -setup_requires = pyscaffold>=3.2a0,<3.3a0 -# Add here dependencies of your project (semicolon/line-separated), e.g. 
-install_requires = - coincurve; python_version<"3.11" - coincurve>=17.0.0; python_version>="3.11" # Technically, this should be >=18.0.0 but there is a conflict with eciespy - aiohttp>=3.8.3 - eciespy; python_version<"3.11" - eciespy>=0.3.13; python_version>="3.11" - typing_extensions - typer - aleph-message~=0.4.3 - eth_account>=0.4.0 - # Required to fix a dependency issue with parsimonious and Python3.11 - eth_abi==4.0.0b2; python_version>="3.11" - python-magic -# The usage of test_requires is discouraged, see `Dependency Management` docs -# tests_require = pytest; pytest-cov -# Require a specific Python version, e.g. Python 2.7 or >= 3.4 -# python_requires = >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* - -[options.packages.find] -where = src -exclude = - tests - -[options.extras_require] -# Add here additional requirements for extra features, to install with: -# `pip install aleph-sdk-python[PDF]` like: -# PDF = ReportLab; RXP -# Add here test requirements (semicolon/line-separated) -testing = - aiomqtt<=0.1.3 - psutil - pytest - pytest-cov - pytest-asyncio - pytest-mock - mypy - secp256k1 - pynacl - base58 - fastapi - # httpx is required in tests by fastapi.testclient - httpx - requests - aleph-pytezos==0.1.1 - types-certifi - types-setuptools - black - isort - flake8 - substrate-interface - py-sr25519-bindings - ledgereth==0.9.0 - aiodns -dns = - aiodns -mqtt = - aiomqtt<=0.1.3 - certifi - Click -nuls2 = - aleph-nuls2 -ethereum = - eth_account>=0.4.0 - # Required to fix a dependency issue with parsimonious and Python3.11 - eth_abi==4.0.0b2; python_version>="3.11" -polkadot = - substrate-interface - py-sr25519-bindings -cosmos = - cosmospy -solana = - pynacl - base58 -tezos = - pynacl - aleph-pytezos==0.1.1 -ledger = - ledgereth==0.9.0 -docs = - sphinxcontrib-plantuml - -[options.entry_points] -# Add here console scripts like: -# For example: -# console_scripts = -# fibonacci = aleph.sdk.skeleton:run -# And any other entry points, for example: -# pyscaffold.cli = -# 
awesome = pyscaffoldext.awesome.extension:AwesomeExtension - -[test] -# py.test options when running `python setup.py test` -# addopts = --verbose -extras = True - -[tool:pytest] -# Options for py.test: -# Specify command line options as you would do when invoking py.test directly. -# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml -# in order to write a coverage file that can be read by Jenkins. -addopts = - --cov aleph.sdk --cov-report term-missing - --verbose - -m "not ledger_hardware" -norecursedirs = - dist - build - .tox -testpaths = tests -markers = - "ledger_hardware: marks tests as requiring ledger hardware" -[aliases] -dists = bdist_wheel - -[bdist_wheel] -# Use this option if your package is pure-python -universal = 0 - -[build_sphinx] -source_dir = docs -build_dir = build/sphinx - -[devpi:upload] -# Options for the devpi: PyPI server and packaging tool -# VCS export must be deactivated since we are using setuptools-scm -no-vcs = 1 -formats = bdist_wheel - -[flake8] -# Some sane defaults for the code style checker flake8 -exclude = - .tox - build - dist - .eggs - docs/conf.py -ignore = E501 W291 W503 E203 E704 - -[isort] -profile = black - -[pyscaffold] -# PyScaffold's parameters when the project was created. -# This will be used when updating. Do not change! -version = 3.2.1 -package = aleph.sdk diff --git a/setup.py b/setup.py deleted file mode 100644 index 9b29e6b3..00000000 --- a/setup.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- -""" - Setup file for aleph.sdk - Use setup.cfg to configure your project. - - This file was generated with PyScaffold 3.2.1. - PyScaffold helps you to put up the scaffold of your new Python project. 
- Learn more under: https://pyscaffold.org/ -""" -import sys - -from pkg_resources import VersionConflict, require -from setuptools import setup - -try: - require("setuptools>=38.3") -except VersionConflict: - print("Error: version of setuptools is too old (<38.3)!") - sys.exit(1) - - -if __name__ == "__main__": - setup(use_pyscaffold=True) diff --git a/src/aleph/py.typed b/src/aleph/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/__init__.py b/src/aleph/sdk/__init__.py index a3ecc693..358ddd96 100644 --- a/src/aleph/sdk/__init__.py +++ b/src/aleph/sdk/__init__.py @@ -1,17 +1,14 @@ -from pkg_resources import DistributionNotFound, get_distribution +from importlib.metadata import PackageNotFoundError, version from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient try: # Change here if project is renamed and does not equal the package name - dist_name = "aleph-sdk-python" - __version__ = get_distribution(dist_name).version -except DistributionNotFound: + __version__ = version("aleph-sdk-python") +except PackageNotFoundError: __version__ = "unknown" -finally: - del get_distribution, DistributionNotFound -__all__ = ["AlephHttpClient", "AuthenticatedAlephHttpClient"] +__all__ = ["__version__", "AlephHttpClient", "AuthenticatedAlephHttpClient"] def __getattr__(name): diff --git a/src/aleph/sdk/account.py b/src/aleph/sdk/account.py index 6ec08c83..83b6e1ea 100644 --- a/src/aleph/sdk/account.py +++ b/src/aleph/sdk/account.py @@ -1,63 +1,213 @@ -import asyncio import logging from pathlib import Path -from typing import Optional, Type, TypeVar +from typing import Dict, Literal, Optional, Type, TypeVar, Union, overload + +from aleph_message.models import Chain +from ledgereth.exceptions import LedgerError +from typing_extensions import TypeAlias from aleph.sdk.chains.common import get_fallback_private_key from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.chains.remote import RemoteAccount -from aleph.sdk.conf 
import settings -from aleph.sdk.types import AccountFromPrivateKey +from aleph.sdk.chains.evm import EVMAccount +from aleph.sdk.chains.solana import SOLAccount +from aleph.sdk.chains.substrate import DOTAccount +from aleph.sdk.chains.svm import SVMAccount +from aleph.sdk.conf import AccountType, load_main_configuration, settings +from aleph.sdk.evm_utils import get_chains_with_super_token +from aleph.sdk.types import AccountFromPrivateKey, HardwareAccount +from aleph.sdk.wallets.ledger import LedgerETHAccount logger = logging.getLogger(__name__) T = TypeVar("T", bound=AccountFromPrivateKey) +AccountTypes: TypeAlias = Union["AccountFromPrivateKey", "HardwareAccount"] + +chain_account_map: Dict[Chain, Type[T]] = { # type: ignore + Chain.ARBITRUM: EVMAccount, + Chain.AURORA: EVMAccount, + Chain.AVAX: ETHAccount, + Chain.BASE: ETHAccount, + Chain.BLAST: EVMAccount, + Chain.BOB: EVMAccount, + Chain.BSC: EVMAccount, + Chain.CYBER: EVMAccount, + Chain.DOT: DOTAccount, + Chain.ECLIPSE: SVMAccount, + Chain.ETH: ETHAccount, + Chain.FRAXTAL: EVMAccount, + Chain.INK: EVMAccount, + Chain.LINEA: EVMAccount, + Chain.LISK: EVMAccount, + Chain.METIS: EVMAccount, + Chain.MODE: EVMAccount, + Chain.NEO: EVMAccount, + Chain.OPTIMISM: EVMAccount, + Chain.POL: EVMAccount, + Chain.SOL: SOLAccount, + Chain.SOMNIA: EVMAccount, + Chain.SONIC: EVMAccount, + Chain.UNICHAIN: EVMAccount, + Chain.WORLDCHAIN: EVMAccount, + Chain.ZORA: EVMAccount, +} + +def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]: + return chain_account_map.get(chain) or ETHAccount # type: ignore -def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T: + +def account_from_hex_string( + private_key_str: str, + account_type: Optional[Type[AccountFromPrivateKey]], + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: if private_key_str.startswith("0x"): private_key_str = private_key_str[2:] - return account_type(bytes.fromhex(private_key_str)) + if not chain: + chain = 
settings.DEFAULT_CHAIN + if not account_type: + account_type = load_chain_account_type(chain) # type: ignore + account = account_type( + bytes.fromhex(private_key_str), + **({"chain": chain} if type(account_type) in [ETHAccount, EVMAccount] else {}), + ) # type: ignore -def account_from_file(private_key_path: Path, account_type: Type[T]) -> T: + if chain in get_chains_with_super_token(): + account.switch_chain(chain) + return account # type: ignore + + +def account_from_file( + private_key_path: Path, + account_type: Optional[Type[AccountFromPrivateKey]], + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: private_key = private_key_path.read_bytes() - return account_type(private_key) + + if not chain: + chain = settings.DEFAULT_CHAIN + if not account_type: + account_type = load_chain_account_type(chain) # type: ignore + account = account_type( + private_key, + **({"chain": chain} if type(account_type) in [ETHAccount, EVMAccount] else {}), + ) # type: ignore + + if chain in get_chains_with_super_token(): + account.switch_chain(chain) + return account + + +@overload +def _load_account( + private_key_str: str, + private_key_path: None = None, + account_type: Type[AccountFromPrivateKey] = ..., + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: ... + + +@overload +def _load_account( + private_key_str: Literal[None], + private_key_path: Path, + account_type: Type[AccountFromPrivateKey] = ..., + chain: Optional[Chain] = None, +) -> AccountFromPrivateKey: ... + + +@overload +def _load_account( + private_key_str: Literal[None], + private_key_path: Literal[None], + account_type: Type[HardwareAccount], + chain: Optional[Chain] = None, +) -> HardwareAccount: ... +@overload def _load_account( private_key_str: Optional[str] = None, private_key_path: Optional[Path] = None, - account_type: Type[AccountFromPrivateKey] = ETHAccount, -) -> AccountFromPrivateKey: - """Load private key from a string or a file. 
+ account_type: Optional[Type[AccountTypes]] = None, + chain: Optional[Chain] = None, +) -> AccountTypes: ... + - Only keys that accounts that can be initiated from a +def _load_account( + private_key_str: Optional[str] = None, + private_key_path: Optional[Path] = None, + account_type: Optional[Type[AccountTypes]] = None, + chain: Optional[Chain] = None, +) -> AccountTypes: + """Load an account from a private key string or file, or from the configuration file. + + This function can return different types of accounts based on the input: + - AccountFromPrivateKey: When a private key is provided (string or file) + - HardwareAccount: When config has AccountType.HARDWARE and a Ledger device is connected + + The function will attempt to load an account in the following order: + 1. From provided private key string + 2. From provided private key file + 3. From Ledger device (if config.type is HARDWARE) + 4. Generate a fallback private key """ - assert not ( - private_key_str and private_key_path - ), "Private key should be a string or a filepath, not both." 
+ config = load_main_configuration(settings.CONFIG_FILE) + default_chain = settings.DEFAULT_CHAIN - if private_key_str: - logger.debug("Using account from string") - return account_from_hex_string(private_key_str, account_type) - elif private_key_path and private_key_path.is_file(): - logger.debug("Using account from file") - return account_from_file(private_key_path, account_type) - elif settings.REMOTE_CRYPTO_HOST: - logger.debug("Using remote account") - loop = asyncio.get_event_loop() - return loop.run_until_complete( - RemoteAccount.from_crypto_host( - host=settings.REMOTE_CRYPTO_HOST, - unix_socket=settings.REMOTE_CRYPTO_UNIX_SOCKET, + if not chain: + if config and hasattr(config, "chain"): + chain = config.chain + logger.debug( + f"Detected {config.chain} account for path {settings.CONFIG_FILE}" ) + else: + chain = default_chain + logger.warning( + f"No main configuration found on path {settings.CONFIG_FILE}, defaulting to {chain}" + ) + + # Loads configuration if no account_type is specified + if not account_type: + account_type = load_chain_account_type(chain) + logger.debug( + f"No account type specified defaulting to {account_type and account_type.__name__}" ) - else: - new_private_key = get_fallback_private_key() - account = account_type(private_key=new_private_key) - logger.info( - f"Generated fallback private key with address {account.get_address()}" - ) - return account + + # Loads private key from a string + if private_key_str: + return account_from_hex_string(private_key_str, None, chain) + + # Loads private key from a file + elif private_key_path and private_key_path.is_file(): + return account_from_file(private_key_path, account_type, chain) # type: ignore + elif config and config.address and config.type == AccountType.HARDWARE: + logger.debug("Using ledger account") + try: + ledger_account = None + if config.derivation_path: + ledger_account = LedgerETHAccount.from_path(config.derivation_path) + else: + ledger_account = 
LedgerETHAccount.from_address(config.address) + + if ledger_account: + # Connect provider to the chain + # Only valid for EVM chain sign we sign TX using device + # and then use Superfluid logic to publish it to BASE / AVAX + if chain: + ledger_account.connect_chain(chain) + return ledger_account + except LedgerError as e: + logger.warning(f"Ledger Error : {e.message}") + raise e + except OSError as e: + logger.warning("Please ensure Udev rules are set to use Ledger") + raise e + + # Fallback: config.path if set, else generate a new private key + new_private_key = get_fallback_private_key() + account = account_from_hex_string(bytes.hex(new_private_key), None, chain) + logger.info(f"Generated fallback private key with address {account.get_address()}") + return account diff --git a/src/aleph/sdk/chains/common.py b/src/aleph/sdk/chains/common.py index 3c7e634e..d2714d62 100644 --- a/src/aleph/sdk/chains/common.py +++ b/src/aleph/sdk/chains/common.py @@ -4,7 +4,7 @@ from typing import Dict, Optional from coincurve.keys import PrivateKey -from ecies import decrypt, encrypt +from typing_extensions import deprecated from aleph.sdk.conf import settings from aleph.sdk.utils import enum_as_str @@ -73,6 +73,7 @@ async def sign_message(self, message: Dict) -> Dict: message = self._setup_sender(message) signature = await self.sign_raw(get_verification_buffer(message)) message["signature"] = signature.hex() + return message @abstractmethod @@ -100,6 +101,7 @@ def get_public_key(self) -> str: """ raise NotImplementedError + @deprecated("This method will be moved to its own module `aleph.sdk.encryption`") async def encrypt(self, content: bytes) -> bytes: """ Encrypts a message using the account's public key. 
@@ -108,12 +110,19 @@ async def encrypt(self, content: bytes) -> bytes: Returns: bytes: Encrypted content as bytes """ + try: + from ecies import encrypt + except ImportError: + raise ImportError( + "Install `eciespy` or `aleph-sdk-python[encryption]` to use this method" + ) if self.CURVE == "secp256k1": value: bytes = encrypt(self.get_public_key(), content) return value else: raise NotImplementedError + @deprecated("This method will be moved to its own module `aleph.sdk.encryption`") async def decrypt(self, content: bytes) -> bytes: """ Decrypts a message using the account's private key. @@ -122,6 +131,12 @@ async def decrypt(self, content: bytes) -> bytes: Returns: bytes: Decrypted content as bytes """ + try: + from ecies import decrypt + except ImportError: + raise ImportError( + "Install `eciespy` or `aleph-sdk-python[encryption]` to use this method" + ) if self.CURVE == "secp256k1": value: bytes = decrypt(self.private_key, content) return value @@ -156,10 +171,3 @@ def get_fallback_private_key(path: Optional[Path] = None) -> bytes: if not default_key_path.exists(): default_key_path.symlink_to(path) return private_key - - -def bytes_from_hex(hex_string: str) -> bytes: - if hex_string.startswith("0x"): - hex_string = hex_string[2:] - hex_string = bytes.fromhex(hex_string) - return hex_string diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index 124fbee7..22601897 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -1,28 +1,225 @@ +import asyncio +import base64 +from abc import abstractmethod +from decimal import Decimal from pathlib import Path -from typing import Optional, Union +from typing import Awaitable, Dict, Optional, Union -from eth_account import Account +from aleph_message.models import Chain +from eth_account import Account # type: ignore from eth_account.messages import encode_defunct from eth_account.signers.local import LocalAccount from eth_keys.exceptions import BadSignature as 
EthBadSignatureError +from superfluid import Web3FlowInfo +from web3 import Web3 +from web3.exceptions import ContractCustomError +from web3.middleware import ExtraDataToPOAMiddleware +from web3.types import TxParams, TxReceipt -from ..exceptions import BadSignatureError -from .common import ( - BaseAccount, - bytes_from_hex, - get_fallback_private_key, - get_public_key, +from aleph.sdk.exceptions import InsufficientFundsError +from aleph.sdk.types import TokenType + +from ..conf import settings +from ..connectors.superfluid import Superfluid +from ..evm_utils import ( + BALANCEOF_ABI, + MIN_ETH_BALANCE_WEI, + FlowUpdate, + from_wei_token, + get_chain_id, + get_chains_with_super_token, + get_rpc, + get_super_token_address, + get_token_address, ) +from ..exceptions import BadSignatureError +from ..utils import bytes_from_hex +from .common import BaseAccount, get_fallback_private_key, get_public_key + +class BaseEthAccount(BaseAccount): + """Base logic to interact with EVM blockchains""" -class ETHAccount(BaseAccount): CHAIN = "ETH" CURVE = "secp256k1" + + _provider: Optional[Web3] + chain: Optional[Chain] + chain_id: Optional[int] + rpc: Optional[str] + superfluid_connector: Optional[Superfluid] + + def __init__(self, chain: Optional[Chain] = None): + self.chain = chain + self.connect_chain(chain=chain) + + @abstractmethod + async def _sign_and_send_transaction(self, tx_params: TxParams) -> str: + """ + Sign and broadcast a transaction using the provided ETHAccount + @param tx_params - Transaction parameters + @returns - str - Transaction hash + """ + raise NotImplementedError + + def connect_chain(self, chain: Optional[Chain] = None): + self.chain = chain + if self.chain: + self.chain_id = get_chain_id(self.chain) + self.rpc = get_rpc(self.chain) + self._provider = Web3(Web3.HTTPProvider(self.rpc)) + if chain == Chain.BSC: + self._provider.middleware_onion.inject( + ExtraDataToPOAMiddleware, "geth_poa", layer=0 + ) + else: + self.chain_id = None + self.rpc = None + 
self._provider = None + + if chain in get_chains_with_super_token() and self._provider: + self.superfluid_connector = Superfluid(self) + else: + self.superfluid_connector = None + + def switch_chain(self, chain: Optional[Chain] = None): + self.connect_chain(chain=chain) + + def can_transact(self, tx: TxParams, block=True) -> bool: + balance_wei = self.get_eth_balance() + try: + assert self._provider is not None + + estimated_gas = self._provider.eth.estimate_gas(tx) + + gas_price = tx.get("gasPrice", self._provider.eth.gas_price) + + if "maxFeePerGas" in tx: + max_fee = tx["maxFeePerGas"] + total_fee_wei = estimated_gas * max_fee + else: + total_fee_wei = estimated_gas * gas_price + + total_fee_wei = int(total_fee_wei * 1.2) + + except ContractCustomError: + total_fee_wei = MIN_ETH_BALANCE_WEI # Fallback if estimation fails + + required_fee_wei = total_fee_wei + (tx.get("value", 0)) + + valid = balance_wei > required_fee_wei if self.chain else False + if not valid and block: + raise InsufficientFundsError( + token_type=TokenType.GAS, + required_funds=float(from_wei_token(required_fee_wei)), + available_funds=float(from_wei_token(balance_wei)), + ) + return valid + + def get_eth_balance(self) -> Decimal: + if not self._provider: + raise ValueError( + "Provider not set. Please configure a provider before checking balance." 
+ ) + + return Decimal(self._provider.eth.get_balance(self.get_address())) + + def get_token_balance(self) -> Decimal: + if self.chain and self._provider: + contact_address = get_token_address(self.chain) + if contact_address: + contract = self._provider.eth.contract( + address=contact_address, abi=BALANCEOF_ABI + ) + return Decimal(contract.functions.balanceOf(self.get_address()).call()) + return Decimal(0) + + def get_super_token_balance(self) -> Decimal: + if self.chain and self._provider: + contact_address = get_super_token_address(self.chain) + if contact_address: + contract = self._provider.eth.contract( + address=contact_address, abi=BALANCEOF_ABI + ) + return Decimal(contract.functions.balanceOf(self.get_address()).call()) + return Decimal(0) + + def can_start_flow(self, flow: Decimal) -> bool: + """Check if the account has enough funds to start a Superfluid flow of the given size.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to check a flow") + return self.superfluid_connector.can_start_flow(flow) + + def create_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: + """Creat a Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to create a flow") + return self.superfluid_connector.create_flow(receiver=receiver, flow=flow) + + def get_flow(self, receiver: str) -> Awaitable[Web3FlowInfo]: + """Get the Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to get a flow") + return self.superfluid_connector.get_flow( + sender=self.get_address(), receiver=receiver + ) + + def update_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: + """Update the Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required 
to update a flow") + return self.superfluid_connector.update_flow(receiver=receiver, flow=flow) + + def delete_flow(self, receiver: str) -> Awaitable[str]: + """Delete the Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to delete a flow") + return self.superfluid_connector.delete_flow(receiver=receiver) + + def manage_flow( + self, + receiver: str, + flow: Decimal, + update_type: FlowUpdate, + ) -> Awaitable[Optional[str]]: + """Manage the Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to manage a flow") + return self.superfluid_connector.manage_flow( + receiver=receiver, flow=flow, update_type=update_type + ) + + +class ETHAccount(BaseEthAccount): + """Interact with an Ethereum address or key pair on EVM blockchains""" + _account: LocalAccount - def __init__(self, private_key: bytes): + def __init__( + self, + private_key: bytes, + chain: Optional[Chain] = None, + ): self.private_key = private_key self._account = Account.from_key(self.private_key) + super().__init__(chain=chain) + + @staticmethod + def from_mnemonic(mnemonic: str, chain: Optional[Chain] = None) -> "ETHAccount": + Account.enable_unaudited_hdwallet_features() + return ETHAccount( + private_key=Account.from_mnemonic(mnemonic=mnemonic).key, chain=chain + ) + + def export_private_key(self) -> str: + """Export the private key using standard format.""" + return f"0x{base64.b16encode(self.private_key).decode().lower()}" + + def get_address(self) -> str: + return self._account.address + + def get_public_key(self) -> str: + return "0x" + get_public_key(private_key=self._account.key).hex() async def sign_raw(self, buffer: bytes) -> bytes: """Sign a raw buffer.""" @@ -30,20 +227,51 @@ async def sign_raw(self, buffer: bytes) -> bytes: sig = self._account.sign_message(msghash) return sig["signature"] 
- def get_address(self) -> str: - return self._account.address + async def sign_message(self, message: Dict) -> Dict: + """ + Returns a signed message from an aleph Cloud message. + Args: + message: Message to sign + Returns: + Dict: Signed message + """ + signed_message = await super().sign_message(message) - def get_public_key(self) -> str: - return "0x" + get_public_key(private_key=self._account.key).hex() + # Apply that fix as seems that sometimes the .hex() method doesn't add the 0x str at the beginning + if not str(signed_message["signature"]).startswith("0x"): + signed_message["signature"] = "0x" + signed_message["signature"] - @staticmethod - def from_mnemonic(mnemonic: str) -> "ETHAccount": - Account.enable_unaudited_hdwallet_features() - return ETHAccount(private_key=Account.from_mnemonic(mnemonic=mnemonic).key) + return signed_message + + async def _sign_and_send_transaction(self, tx_params: TxParams) -> str: + """ + Sign and broadcast a transaction using the provided ETHAccount + @param tx_params - Transaction parameters + @returns - str - Transaction hash + """ + + def sign_and_send() -> TxReceipt: + if self._provider is None: + raise ValueError("Provider not connected") + signed_tx = self._provider.eth.account.sign_transaction( + tx_params, self._account.key + ) + + tx_hash = self._provider.eth.send_raw_transaction(signed_tx.raw_transaction) + tx_receipt = self._provider.eth.wait_for_transaction_receipt( + tx_hash, settings.TX_TIMEOUT + ) + return tx_receipt + + loop = asyncio.get_running_loop() + tx_receipt = await loop.run_in_executor(None, sign_and_send) + return tx_receipt["transactionHash"].hex() -def get_fallback_account(path: Optional[Path] = None) -> ETHAccount: - return ETHAccount(private_key=get_fallback_private_key(path=path)) +def get_fallback_account( + path: Optional[Path] = None, chain: Optional[Chain] = None +) -> ETHAccount: + return ETHAccount(private_key=get_fallback_private_key(path=path), chain=chain) def verify_signature( diff 
--git a/src/aleph/sdk/chains/evm.py b/src/aleph/sdk/chains/evm.py new file mode 100644 index 00000000..a5eeed84 --- /dev/null +++ b/src/aleph/sdk/chains/evm.py @@ -0,0 +1,57 @@ +from decimal import Decimal +from pathlib import Path +from typing import Awaitable, Optional + +from aleph_message.models import Chain +from eth_account import Account # type: ignore + +from ..evm_utils import FlowUpdate +from .common import get_fallback_private_key +from .ethereum import ETHAccount + + +class EVMAccount(ETHAccount): + def __init__(self, private_key: bytes, chain: Optional[Chain] = None): + super().__init__(private_key, chain) + # Decide if we have to send also the specified chain value or always use ETH + # if chain: + # self.CHAIN = chain + + @staticmethod + def from_mnemonic(mnemonic: str, chain: Optional[Chain] = None) -> "EVMAccount": + Account.enable_unaudited_hdwallet_features() + return EVMAccount( + private_key=Account.from_mnemonic(mnemonic=mnemonic).key, chain=chain + ) + + def get_token_balance(self) -> Decimal: + raise ValueError(f"Token not implemented for this chain {self.CHAIN}") + + def get_super_token_balance(self) -> Decimal: + raise ValueError(f"Super token not implemented for this chain {self.CHAIN}") + + def can_start_flow(self, flow: Decimal) -> bool: + raise ValueError(f"Flow checking not implemented for this chain {self.CHAIN}") + + def create_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: + raise ValueError(f"Flow creation not implemented for this chain {self.CHAIN}") + + def get_flow(self, receiver: str): + raise ValueError(f"Get flow not implemented for this chain {self.CHAIN}") + + def update_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: + raise ValueError(f"Flow update not implemented for this chain {self.CHAIN}") + + def delete_flow(self, receiver: str) -> Awaitable[str]: + raise ValueError(f"Flow deletion not implemented for this chain {self.CHAIN}") + + def manage_flow( + self, receiver: str, flow: Decimal, 
update_type: FlowUpdate + ) -> Awaitable[Optional[str]]: + raise ValueError(f"Flow management not implemented for this chain {self.CHAIN}") + + +def get_fallback_account( + path: Optional[Path] = None, chain: Optional[Chain] = None +) -> ETHAccount: + return ETHAccount(private_key=get_fallback_private_key(path=path), chain=chain) diff --git a/src/aleph/sdk/chains/sol.py b/src/aleph/sdk/chains/sol.py index ff870a4d..b8e85962 100644 --- a/src/aleph/sdk/chains/sol.py +++ b/src/aleph/sdk/chains/sol.py @@ -1,93 +1,9 @@ -import json -from pathlib import Path -from typing import Dict, Optional, Union +import warnings -import base58 -from nacl.exceptions import BadSignatureError as NaclBadSignatureError -from nacl.public import PrivateKey, SealedBox -from nacl.signing import SigningKey, VerifyKey +from aleph.sdk.chains.solana import * # noqa -from ..exceptions import BadSignatureError -from .common import BaseAccount, get_fallback_private_key, get_verification_buffer - - -def encode(item): - return base58.b58encode(bytes(item)).decode("ascii") - - -class SOLAccount(BaseAccount): - CHAIN = "SOL" - CURVE = "curve25519" - _signing_key: SigningKey - _private_key: PrivateKey - - def __init__(self, private_key: bytes): - self.private_key = private_key - self._signing_key = SigningKey(self.private_key) - self._private_key = self._signing_key.to_curve25519_private_key() - - async def sign_message(self, message: Dict) -> Dict: - """Sign a message inplace.""" - message = self._setup_sender(message) - verif = get_verification_buffer(message) - signature = await self.sign_raw(verif) - sig = { - "publicKey": self.get_address(), - "signature": encode(signature), - } - message["signature"] = json.dumps(sig) - return message - - async def sign_raw(self, buffer: bytes) -> bytes: - """Sign a raw buffer.""" - sig = self._signing_key.sign(buffer) - return sig.signature - - def get_address(self) -> str: - return encode(self._signing_key.verify_key) - - def get_public_key(self) -> str: - return 
bytes(self._signing_key.verify_key.to_curve25519_public_key()).hex() - - async def encrypt(self, content) -> bytes: - value: bytes = bytes(SealedBox(self._private_key.public_key).encrypt(content)) - return value - - async def decrypt(self, content) -> bytes: - value: bytes = SealedBox(self._private_key).decrypt(content) - return value - - -def get_fallback_account(path: Optional[Path] = None) -> SOLAccount: - return SOLAccount(private_key=get_fallback_private_key(path=path)) - - -def generate_key() -> bytes: - privkey = bytes(SigningKey.generate()) - return privkey - - -def verify_signature( - signature: Union[bytes, str], - public_key: Union[bytes, str], - message: Union[bytes, str], -): - """ - Verifies a signature. - Args: - signature: The signature to verify. Can be a base58 encoded string or bytes. - public_key: The public key to use for verification. Can be a base58 encoded string or bytes. - message: The message to verify. Can be an utf-8 string or bytes. - Raises: - BadSignatureError: If the signature is invalid. 
- """ - if isinstance(signature, str): - signature = base58.b58decode(signature) - if isinstance(message, str): - message = message.encode("utf-8") - if isinstance(public_key, str): - public_key = base58.b58decode(public_key) - try: - VerifyKey(public_key).verify(message, signature) - except NaclBadSignatureError as e: - raise BadSignatureError from e +warnings.warn( + "aleph.sdk.chains.sol is deprecated, use aleph.sdk.chains.solana instead", + DeprecationWarning, + stacklevel=1, +) diff --git a/src/aleph/sdk/chains/solana.py b/src/aleph/sdk/chains/solana.py new file mode 100644 index 00000000..920ca8a0 --- /dev/null +++ b/src/aleph/sdk/chains/solana.py @@ -0,0 +1,187 @@ +import json +from pathlib import Path +from typing import Dict, List, Optional, Union + +import base58 +from nacl.exceptions import BadSignatureError as NaclBadSignatureError +from nacl.public import PrivateKey, SealedBox +from nacl.signing import SigningKey, VerifyKey + +from ..exceptions import BadSignatureError +from .common import BaseAccount, get_fallback_private_key, get_verification_buffer + + +def encode(item): + return base58.b58encode(bytes(item)).decode("ascii") + + +class SOLAccount(BaseAccount): + CHAIN = "SOL" + CURVE = "curve25519" + _signing_key: SigningKey + _private_key: PrivateKey + + def __init__(self, private_key: bytes): + self.private_key = parse_private_key(private_key_from_bytes(private_key)) + self._signing_key = SigningKey(self.private_key) + self._private_key = self._signing_key.to_curve25519_private_key() + + async def sign_message(self, message: Dict) -> Dict: + """Sign a message inplace.""" + message = self._setup_sender(message) + verif = get_verification_buffer(message) + signature = await self.sign_raw(verif) + sig = { + "publicKey": self.get_address(), + "signature": encode(signature), + } + message["signature"] = json.dumps(sig) + return message + + async def sign_raw(self, buffer: bytes) -> bytes: + """Sign a raw buffer.""" + sig = 
self._signing_key.sign(buffer) + return sig.signature + + def export_private_key(self) -> str: + """Export the private key using Phantom format.""" + return base58.b58encode( + self.private_key + self._signing_key.verify_key.encode() + ).decode() + + def get_address(self) -> str: + return encode(self._signing_key.verify_key) + + def get_public_key(self) -> str: + return bytes(self._signing_key.verify_key.to_curve25519_public_key()).hex() + + async def encrypt(self, content) -> bytes: + value: bytes = bytes(SealedBox(self._private_key.public_key).encrypt(content)) + return value + + async def decrypt(self, content) -> bytes: + value: bytes = SealedBox(self._private_key).decrypt(content) + return value + + +def get_fallback_account(path: Optional[Path] = None) -> SOLAccount: + return SOLAccount(private_key=get_fallback_private_key(path=path)) + + +def generate_key() -> bytes: + privkey = bytes(SigningKey.generate()) + return privkey + + +def verify_signature( + signature: Union[bytes, str], + public_key: Union[bytes, str], + message: Union[bytes, str], +): + """ + Verifies a signature. + Args: + signature: The signature to verify. Can be a base58 encoded string or bytes. + public_key: The public key to use for verification. Can be a base58 encoded string or bytes. + message: The message to verify. Can be an utf-8 string or bytes. + Raises: + BadSignatureError: If the signature is invalid.! 
+ """ + if isinstance(signature, str): + signature = base58.b58decode(signature) + if isinstance(message, str): + message = message.encode("utf-8") + if isinstance(public_key, str): + public_key = base58.b58decode(public_key) + try: + VerifyKey(public_key).verify(message, signature) + except NaclBadSignatureError as e: + raise BadSignatureError from e + + +def private_key_from_bytes( + private_key_bytes: bytes, output_format: str = "base58" +) -> Union[str, List[int], bytes]: + """ + Convert a Solana private key in bytes back to different formats (base58 string, uint8 list, or raw bytes). + + - For base58 string: Encode the bytes into a base58 string. + - For uint8 list: Convert the bytes into a list of integers. + - For raw bytes: Return as-is. + + Args: + private_key_bytes (bytes): The private key in byte format. + output_format (str): The format to return ('base58', 'list', 'bytes'). + + Returns: + The private key in the requested format. + + Raises: + ValueError: If the output_format is not recognized or the private key length is invalid. + """ + if not isinstance(private_key_bytes, bytes): + raise ValueError("Expected the private key in bytes.") + + if len(private_key_bytes) != 32: + raise ValueError("Solana private key must be exactly 32 bytes long.") + + if output_format == "base58": + return base58.b58encode(private_key_bytes).decode("utf-8") + + elif output_format == "list": + return list(private_key_bytes) + + elif output_format == "bytes": + return private_key_bytes + + else: + raise ValueError("Invalid output format. Choose 'base58', 'list', or 'bytes'.") + + +def parse_private_key(private_key: Union[str, List[int], bytes]) -> bytes: + """ + Parse the private key which could be either: + - a base58-encoded string (which may contain both private and public key) + - a list of uint8 integers (which may contain both private and public key) + - a byte array (exactly 32 bytes) + + Returns: + bytes: The private key in byte format (32 bytes). 
+ + Raises: + ValueError: If the private key format is invalid or the length is incorrect. + """ + # If the private key is already in byte format + if isinstance(private_key, bytes): + if len(private_key) != 32: + raise ValueError("The private key in bytes must be exactly 32 bytes long.") + return private_key + + # If the private key is a base58-encoded string + elif isinstance(private_key, str): + try: + decoded_key = base58.b58decode(private_key) + if len(decoded_key) not in [32, 64]: + raise ValueError( + "The base58 decoded private key must be either 32 or 64 bytes long." + ) + return decoded_key[:32] + except Exception as e: + raise ValueError(f"Invalid base58 encoded private key: {e}") + + # If the private key is a list of uint8 integers + elif isinstance(private_key, list): + if all(isinstance(i, int) and 0 <= i <= 255 for i in private_key): + byte_key = bytes(private_key) + if len(byte_key) < 32: + raise ValueError("The uint8 array must contain at least 32 elements.") + return byte_key[:32] # Take the first 32 bytes (private key) + else: + raise ValueError( + "Invalid uint8 array, must contain integers between 0 and 255." + ) + + else: + raise ValueError( + "Unsupported private key format. Must be a base58 string, bytes, or a list of uint8 integers." 
+ ) diff --git a/src/aleph/sdk/chains/substrate.py b/src/aleph/sdk/chains/substrate.py index 13795568..f4d18a0d 100644 --- a/src/aleph/sdk/chains/substrate.py +++ b/src/aleph/sdk/chains/substrate.py @@ -9,7 +9,8 @@ from ..conf import settings from ..exceptions import BadSignatureError -from .common import BaseAccount, bytes_from_hex, get_verification_buffer +from ..utils import bytes_from_hex +from .common import BaseAccount, get_verification_buffer logger = logging.getLogger(__name__) diff --git a/src/aleph/sdk/chains/svm.py b/src/aleph/sdk/chains/svm.py new file mode 100644 index 00000000..80f433dd --- /dev/null +++ b/src/aleph/sdk/chains/svm.py @@ -0,0 +1,13 @@ +from typing import Optional + +from aleph_message.models import Chain + +from .solana import SOLAccount + + +class SVMAccount(SOLAccount): + def __init__(self, private_key: bytes, chain: Optional[Chain] = None): + super().__init__(private_key=private_key) + # Same as EVM ACCOUNT need to decided if we want to send the specified chain or always use SOL + if chain: + self.CHAIN = chain diff --git a/src/aleph/sdk/chains/tezos.py b/src/aleph/sdk/chains/tezos.py index cffa3e78..c4ee08ab 100644 --- a/src/aleph/sdk/chains/tezos.py +++ b/src/aleph/sdk/chains/tezos.py @@ -2,9 +2,9 @@ from pathlib import Path from typing import Dict, Optional, Union -from aleph_pytezos.crypto.key import Key from nacl.public import SealedBox from nacl.signing import SigningKey +from pytezos_crypto.key import Key from .common import BaseAccount, get_fallback_private_key, get_verification_buffer diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 3335ad86..3daa198e 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -1,11 +1,13 @@ # An interface for all clients to implement. 
- +import json import logging +import time from abc import ABC, abstractmethod from pathlib import Path from typing import ( Any, AsyncIterable, + Coroutine, Dict, Iterable, List, @@ -18,24 +20,62 @@ from aleph_message.models import ( AlephMessage, - MessagesResponse, + ExecutableContent, + ItemHash, + ItemType, MessageType, Payment, PostMessage, + parse_message, +) +from aleph_message.models.execution.environment import ( + HostRequirements, + HypervisorType, + TrustedExecutionEnvironment, ) from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus +from typing_extensions import deprecated + +from aleph.sdk.conf import settings +from aleph.sdk.types import Account, Authorization, SecurityAggregateContent +from aleph.sdk.utils import extended_json_encoder from ..query.filters import MessageFilter, PostFilter -from ..query.responses import PostsResponse +from ..query.responses import MessagesResponse, PostsResponse, PriceResponse from ..types import GenericMessage, StorageEnum -from ..utils import Writable +from ..utils import Writable, compute_sha256 DEFAULT_PAGE_SIZE = 200 class AlephClient(ABC): + async def get_aggregate(self, address: str, key: str) -> Optional[Dict[str, Dict]]: + """ + Get a value from the aggregate store by owner address and item key. + Returns None if no aggregate was found. + + :param address: Address of the owner of the aggregate + :param key: Key of the aggregate + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + + async def get_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Optional[Dict[str, Dict]]: + """ + Get key-value pairs from the aggregate store by owner address. + Returns None if no aggregate was found. 
+ + :param address: Address of the owner of the aggregate + :param keys: Keys of the aggregates to fetch (Default: all items) + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + @deprecated( + "This method is deprecated and will be removed in upcoming versions. Use get_aggregate instead." + ) async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: """ Fetch a value from the aggregate store by owner address and item key. @@ -46,6 +86,9 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: raise NotImplementedError("Did you mean to import `AlephHttpClient`?") @abstractmethod + @deprecated( + "This method is deprecated and will be removed in upcoming versions. Use get_aggregates instead." + ) async def fetch_aggregates( self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: @@ -96,22 +139,33 @@ async def get_posts_iterator( ) page += 1 for post in resp.posts: - yield post + yield post # type: ignore @abstractmethod - async def download_file( - self, - file_hash: str, - ) -> bytes: + async def download_file(self, file_hash: str) -> bytes: """ Get a file from the storage engine as raw bytes. - Warning: Downloading large files can be slow and memory intensive. + Warning: Downloading large files can be slow and memory intensive. Use `download_file_to()` to download them directly to disk instead. :param file_hash: The hash of the file to retrieve. """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + async def download_file_to_path( + self, + file_hash: str, + path: Union[Path, str], + ) -> Path: + """ + Download a file from the storage engine to given path. + + :param file_hash: The hash of the file to retrieve. + :param path: The path to which the file should be saved. 
+ """ + raise NotImplementedError() + async def download_file_ipfs( self, file_hash: str, @@ -217,8 +271,58 @@ def watch_messages( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + def get_estimated_price( + self, + content: ExecutableContent, + ) -> Coroutine[Any, Any, PriceResponse]: + """ + Get Instance/Program content estimated price + + :param content: Instance or Program content + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + + @abstractmethod + def get_program_price( + self, + item_hash: str, + ) -> Coroutine[Any, Any, PriceResponse]: + """ + Get Program message Price + + :param item_hash: item_hash of executable message + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + + async def get_authorizations(self, address: str) -> list[Authorization]: + """ + Retrieves all authorizations for a specific address. + """ + # TODO: update this implementation to use `get_aggregate()` once + # https://github.com/aleph-im/aleph-sdk-python/pull/273 is merged. + # There's currently no way to detect a nonexistent aggregate in generic code just yet. + # fetch_aggregate() throws an implementation-specific ClientResponseError in case of 404. 
+ import aiohttp + + try: + security_aggregate_dict = await self.fetch_aggregate( + address=address, key="security" + ) + except aiohttp.ClientResponseError as e: + if e.status == 404: + return [] + raise + + security_aggregate = SecurityAggregateContent.model_validate( + security_aggregate_dict + ) + return security_aggregate.authorizations + class AuthenticatedAlephClient(AlephClient): + account: Account + @abstractmethod async def create_post( self, @@ -226,7 +330,7 @@ async def create_post( post_type: str, ref: Optional[str] = None, address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, @@ -251,9 +355,9 @@ async def create_post( async def create_aggregate( self, key: str, - content: Mapping[str, Any], + content: dict[str, Any], address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: @@ -263,7 +367,7 @@ async def create_aggregate( :param key: Key to use to store the content :param content: Content to store :param address: Address to use to sign the message - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param inline: Whether to write content inside the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ @@ -282,8 +386,9 @@ async def create_store( ref: Optional[str] = None, storage_engine: StorageEnum = StorageEnum.storage, extra_fields: Optional[dict] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, + payment: Optional[Payment] = None, ) -> Tuple[AlephMessage, MessageStatus]: """ Create a STORE message to store a file on the aleph.im network. 
@@ -300,6 +405,7 @@ async def create_store( :param extra_fields: Extra fields to add to the STORE message (Default: None) :param channel: Channel to post the message to (Default: "TEST") :param sync: If true, waits for the message to be processed by the API server (Default: False) + :param payment: Payment method used to pay for storage (Default: hold on ETH) """ raise NotImplementedError( "Did you mean to import `AuthenticatedAlephHttpClient`?" @@ -311,22 +417,23 @@ async def create_program( program_ref: str, entrypoint: str, runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, + payment: Optional[Payment] = None, vcpus: Optional[int] = None, + memory: Optional[int] = None, timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, internet: bool = True, + allow_amend: bool = False, aleph_api: bool = True, encoding: Encoding = Encoding.zip, + persistent: bool = False, volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[List[dict]] = None, + sync: bool = False, + channel: Optional[str] = settings.DEFAULT_CHANNEL, + storage_engine: StorageEnum = StorageEnum.storage, ) -> Tuple[AlephMessage, MessageStatus]: """ Post a (create) PROGRAM message. 
@@ -334,22 +441,23 @@ async def create_program( :param program_ref: Reference to the program to run :param entrypoint: Entrypoint to run :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param metadata: Metadata to attach to the message :param address: Address to use (Default: account.get_address()) - :param sync: If true, waits for the message to be processed by the API server - :param memory: Memory in MB for the VM to be allocated (Default: 128) + :param payment: Payment method used to pay for the program (Default: None) :param vcpus: Number of vCPUs to allocate (Default: 1) + :param memory: Memory in MB for the VM to be allocated (Default: 128) :param timeout_seconds: Timeout in seconds (Default: 30.0) - :param persistent: Whether the program should be persistent or not (Default: False) - :param allow_amend: Whether the deployed VM image may be changed (Default: False) :param internet: Whether the VM should have internet connectivity. 
(Default: True) + :param allow_amend: Whether the deployed VM image may be changed (Default: False) :param aleph_api: Whether the VM needs access to Aleph messages API (Default: True) :param encoding: Encoding to use (Default: Encoding.zip) + :param persistent: Whether the program should be persistent or not (Default: False) :param volumes: Volumes to mount + :param environment_variables: Environment variables to pass to the program :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver - :param metadata: Metadata to attach to the message + :param sync: If true, waits for the message to be processed by the API server + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") + :param storage_engine: Storage engine to use (Default: "storage") """ raise NotImplementedError( "Did you mean to import `AuthenticatedAlephHttpClient`?" @@ -360,11 +468,10 @@ async def create_instance( self, rootfs: str, rootfs_size: int, - rootfs_name: str, payment: Optional[Payment] = None, - environment_variables: Optional[Mapping[str, str]] = None, + environment_variables: Optional[dict[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, memory: Optional[int] = None, @@ -373,34 +480,39 @@ async def create_instance( allow_amend: bool = False, internet: bool = True, aleph_api: bool = True, + hypervisor: Optional[HypervisorType] = None, + trusted_execution: Optional[TrustedExecutionEnvironment] = None, volumes: Optional[List[Mapping]] = None, volume_persistence: str = "host", ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, + requirements: Optional[HostRequirements] = None, ) -> Tuple[AlephMessage, MessageStatus]: """ - Post a (create) PROGRAM message. + Post a (create) INSTANCE message. 
:param rootfs: Root filesystem to use :param rootfs_size: Size of root filesystem - :param rootfs_name: Name of root filesystem :param payment: Payment method used to pay for the instance :param environment_variables: Environment variables to pass to the program :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param address: Address to use (Default: account.get_address()) :param sync: If true, waits for the message to be processed by the API server - :param memory: Memory in MB for the VM to be allocated (Default: 128) + :param memory: Memory in MB for the VM to be allocated (Default: 2048) :param vcpus: Number of vCPUs to allocate (Default: 1) :param timeout_seconds: Timeout in seconds (Default: 30.0) :param allow_amend: Whether the deployed VM image may be changed (Default: False) :param internet: Whether the VM should have internet connectivity. (Default: True) :param aleph_api: Whether the VM needs access to Aleph messages API (Default: True) + :param hypervisor: Whether the VM should use as Hypervisor, like QEmu or Firecracker (Default: Qemu) + :param trusted_execution: Whether the VM configuration (firmware and policy) to use for Confidential computing (Default: None) :param encoding: Encoding to use (Default: Encoding.zip) :param volumes: Volumes to mount :param volume_persistence: Where volumes are persisted, can be "host" or "store", meaning distributed across Aleph.im (Default: "host") :param ssh_keys: SSH keys to authorize access to the VM :param metadata: Metadata to attach to the message + :param requirements: CRN Requirements needed for the VM execution """ raise NotImplementedError( "Did you mean to import `AuthenticatedAlephHttpClient`?" 
@@ -409,10 +521,10 @@ async def create_instance( @abstractmethod async def forget( self, - hashes: List[str], + hashes: List[ItemHash], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: @@ -425,7 +537,7 @@ async def forget( :param hashes: Hashes of the messages to forget :param reason: Reason for forgetting the messages :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param address: Address to use (Default: account.get_address()) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ @@ -433,12 +545,68 @@ async def forget( "Did you mean to import `AuthenticatedAlephHttpClient`?" ) + async def generate_signed_message( + self, + message_type: MessageType, + content: Dict[str, Any], + channel: Optional[str], + allow_inlining: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + ) -> AlephMessage: + """Generate a signed aleph.im message ready to be sent to the network. + + If the content is not inlined, it will be pushed to the storage engine via the API of a Core Channel Node. + + :param message_type: Type of the message (PostMessage, ...) 
+ :param content: User-defined content of the message + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") + :param allow_inlining: Whether to allow inlining the content of the message (Default: True) + :param storage_engine: Storage engine to use (Default: "storage") + """ + + message_dict: Dict[str, Any] = { + "sender": self.account.get_address(), + "chain": self.account.CHAIN, + "type": message_type, + "content": content, + "time": time.time(), + "channel": channel, + } + + # Use the Pydantic encoder to serialize types like UUID, datetimes, etc. + item_content: str = json.dumps( + content, separators=(",", ":"), default=extended_json_encoder + ) + + if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): + message_dict["item_content"] = item_content + message_dict["item_hash"] = compute_sha256(item_content) + message_dict["item_type"] = ItemType.inline + else: + if storage_engine == StorageEnum.ipfs: + message_dict["item_hash"] = await self.ipfs_push( + content=content, + ) + message_dict["item_type"] = ItemType.ipfs + else: # storage + assert storage_engine == StorageEnum.storage + message_dict["item_hash"] = await self.storage_push( + content=content, + ) + message_dict["item_type"] = ItemType.storage + + message_dict = await self.account.sign_message(message_dict) + return parse_message(message_dict) + + # Alias for backwards compatibility + _prepare_aleph_message = generate_signed_message + @abstractmethod async def submit( self, content: Dict[str, Any], message_type: MessageType, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, storage_engine: StorageEnum = StorageEnum.storage, allow_inlining: bool = True, sync: bool = False, @@ -450,7 +618,7 @@ async def submit( :param content: Content of the message :param message_type: Type of the message - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param storage_engine: Storage 
engine to use (Default: "storage") :param allow_inlining: Whether to allow inlining the content of the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) @@ -475,3 +643,35 @@ async def storage_push(self, content: Mapping) -> str: :param content: The dict-like content to upload """ raise NotImplementedError() + + async def update_all_authorizations(self, authorizations: list[Authorization]): + """ + Updates all authorizations for the current account. + Danger! This will replace all authorizations for the account. Use with care. + + :param authorizations: List of authorizations to set. These authorizations will replace the existing ones. + """ + security_aggregate = SecurityAggregateContent(authorizations=authorizations) + await self.create_aggregate( + key="security", content=security_aggregate.model_dump() + ) + + async def add_authorization(self, authorization: Authorization): + """ + Adds a specific authorization for the current account. + """ + authorizations = await self.get_authorizations(self.account.get_address()) + authorizations.append(authorization) + await self.update_all_authorizations(authorizations) + + async def revoke_all_authorizations(self, address: str): + """ + Revokes all authorizations for a specific address. 
+ """ + authorizations = await self.get_authorizations(self.account.get_address()) + filtered_authorizations = [ + authorization + for authorization in authorizations + if authorization.address != address + ] + await self.update_all_authorizations(filtered_authorizations) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index cf75d986..11aa08f0 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -3,11 +3,11 @@ import logging import ssl import time +from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, Mapping, NoReturn, Optional, Tuple, Union import aiohttp -from aleph_message import parse_message from aleph_message.models import ( AggregateContent, AggregateMessage, @@ -15,33 +15,32 @@ Chain, ForgetContent, ForgetMessage, - InstanceContent, InstanceMessage, + ItemHash, ItemType, MessageType, PostContent, PostMessage, - ProgramContent, ProgramMessage, StoreContent, StoreMessage, ) from aleph_message.models.execution.base import Encoding, Payment, PaymentType from aleph_message.models.execution.environment import ( - FunctionEnvironment, - MachineResources, + HostRequirements, + HypervisorType, + TrustedExecutionEnvironment, ) -from aleph_message.models.execution.instance import RootfsVolume -from aleph_message.models.execution.program import CodeContent, FunctionRuntime -from aleph_message.models.execution.volume import MachineVolume, ParentVolume from aleph_message.status import MessageStatus from ..conf import settings from ..exceptions import BroadcastError, InsufficientFundsError, InvalidMessageError -from ..types import Account, StorageEnum -from ..utils import extended_json_encoder, parse_volume +from ..types import Account, StorageEnum, TokenType +from ..utils import extended_json_encoder, make_instance_content, make_program_content from .abstract 
import AuthenticatedAlephClient from .http import AlephHttpClient +from .services.authenticated_port_forwarder import AuthenticatedPortForwarder +from .services.authenticated_voucher import AuthenticatedVoucher logger = logging.getLogger(__name__) @@ -85,7 +84,11 @@ def __init__( ) self.account = account - async def __aenter__(self) -> "AuthenticatedAlephHttpClient": + async def __aenter__(self): + await super().__aenter__() + # Override services with authenticated versions + self.port_forwarder = AuthenticatedPortForwarder(self) + self.voucher = AuthenticatedVoucher(self) return self async def ipfs_push(self, content: Mapping) -> str: @@ -114,14 +117,14 @@ async def storage_push(self, content: Mapping) -> str: resp.raise_for_status() return (await resp.json()).get("hash") - async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: + async def ipfs_push_file(self, file_content: bytes) -> str: """ Push a file to the IPFS service. :param file_content: The file content to upload """ data = aiohttp.FormData() - data.add_field("file", file_content) + data.add_field("file", BytesIO(file_content)) url = "/api/v0/ipfs/add_file" logger.debug(f"Pushing file to IPFS on {url}") @@ -130,12 +133,12 @@ async def ipfs_push_file(self, file_content: Union[str, bytes]) -> str: resp.raise_for_status() return (await resp.json()).get("hash") - async def storage_push_file(self, file_content) -> str: + async def storage_push_file(self, file_content: bytes) -> Optional[str]: """ Push a file to the storage service. 
""" data = aiohttp.FormData() - data.add_field("file", file_content) + data.add_field("file", BytesIO(file_content)) url = "/api/v0/storage/add_file" logger.debug(f"Posting file on {url}") @@ -258,7 +261,7 @@ async def _broadcast( url = "/api/v0/messages" logger.debug(f"Posting message on {url}") - message_dict = message.dict(include=self.BROADCAST_MESSAGE_FIELDS) + message_dict = message.model_dump(include=self.BROADCAST_MESSAGE_FIELDS) async with self.http_session.post( url, json={ @@ -284,7 +287,7 @@ async def create_post( post_type: str, ref: Optional[str] = None, address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, @@ -300,21 +303,21 @@ async def create_post( ) message, status, _ = await self.submit( - content=content.dict(exclude_none=True), + content=content.model_dump(exclude_none=True), message_type=MessageType.post, channel=channel, allow_inlining=inline, storage_engine=storage_engine, sync=sync, ) - return message, status + return message, status # type: ignore async def create_aggregate( self, key: str, - content: Mapping[str, Any], + content: dict[str, Any], address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, sync: bool = False, ) -> Tuple[AggregateMessage, MessageStatus]: @@ -328,13 +331,13 @@ async def create_aggregate( ) message, status, _ = await self.submit( - content=content_.dict(exclude_none=True), + content=content_.model_dump(exclude_none=True), message_type=MessageType.aggregate, channel=channel, allow_inlining=inline, sync=sync, ) - return message, status + return message, status # type: ignore async def create_store( self, @@ -346,10 +349,14 @@ async def create_store( ref: Optional[str] = None, storage_engine: StorageEnum = StorageEnum.storage, extra_fields: Optional[dict] = None, - channel: 
Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, + payment: Optional[Payment] = None, ) -> Tuple[StoreMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() + payment = payment or Payment( + chain=Chain.ETH, type=PaymentType.hold, receiver=None + ) extra_fields = extra_fields or {} @@ -372,6 +379,7 @@ async def create_store( extra_fields=extra_fields, channel=channel, sync=sync, + payment=payment, ) elif storage_engine == StorageEnum.ipfs: # We do not support authenticated upload for IPFS yet. Use the legacy method @@ -395,123 +403,105 @@ async def create_store( "item_type": storage_engine, "item_hash": file_hash, "time": time.time(), + "payment": payment, } if extra_fields is not None: values.update(extra_fields) - content = StoreContent(**values) + content = StoreContent.model_validate(values) message, status, _ = await self.submit( - content=content.dict(exclude_none=True), + content=content.model_dump(exclude_none=True), message_type=MessageType.store, channel=channel, allow_inlining=True, sync=sync, ) - return message, status + return message, status # type: ignore async def create_program( self, program_ref: str, entrypoint: str, runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, + payment: Optional[Payment] = None, vcpus: Optional[int] = None, + memory: Optional[int] = None, timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, internet: bool = True, + allow_amend: bool = False, aleph_api: bool = True, encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + persistent: 
bool = False, + volumes: Optional[list[Mapping]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[list[dict]] = None, + sync: bool = False, + channel: Optional[str] = settings.DEFAULT_CHANNEL, + storage_engine: StorageEnum = StorageEnum.storage, ) -> Tuple[ProgramMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - # TODO: Check that program_ref, runtime and data_ref exist - - # Register the different ways to trigger a VM - if subscriptions: - # Trigger on HTTP calls and on aleph.im message subscriptions. - triggers = { - "http": True, - "persistent": persistent, - "message": subscriptions, - } - else: - # Trigger on HTTP calls. - triggers = {"http": True, "persistent": persistent} - - volumes: List[MachineVolume] = [parse_volume(volume) for volume in volumes] - - content = ProgramContent( - type="vm-function", + content = make_program_content( + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + metadata=metadata, address=address, + payment=payment, + vcpus=vcpus, + memory=memory, + timeout_seconds=timeout_seconds, + internet=internet, + aleph_api=aleph_api, allow_amend=allow_amend, - code=CodeContent( - encoding=encoding, - entrypoint=entrypoint, - ref=program_ref, - use_latest=True, - ), - on=triggers, - environment=FunctionEnvironment( - reproducible=False, - internet=internet, - aleph_api=aleph_api, - ), - variables=environment_variables, - resources=MachineResources( - vcpus=vcpus, - memory=memory, - seconds=timeout_seconds, - ), - runtime=FunctionRuntime( - ref=runtime, - use_latest=True, - comment=( - "Official aleph.im runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "" - ), - ), - volumes=[parse_volume(volume) for volume in 
volumes], - time=time.time(), - metadata=metadata, + encoding=encoding, + persistent=persistent, + volumes=volumes, + environment_variables=environment_variables, + subscriptions=subscriptions, ) - # Ensure that the version of aleph-message used supports the field. - assert content.on.persistent == persistent - message, status, _ = await self.submit( - content=content.dict(exclude_none=True), + content=content.model_dump(exclude_none=True), message_type=MessageType.program, channel=channel, storage_engine=storage_engine, sync=sync, + raise_on_rejected=False, ) - return message, status + if status in (MessageStatus.PROCESSED, MessageStatus.PENDING): + return message, status # type: ignore + + # get the reason for rejection + rejected_message = await self.get_message_error(message.item_hash) + assert rejected_message, "No rejected message found" + error_code = rejected_message["error_code"] + if error_code == 5: + # not enough balance + details = rejected_message["details"] + errors = details["errors"] + error = errors[0] + account_balance = float(error["account_balance"]) + required_balance = float(error["required_balance"]) + raise InsufficientFundsError( + token_type=TokenType.ALEPH, + required_funds=required_balance, + available_funds=account_balance, + ) + else: + raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def create_instance( self, rootfs: str, rootfs_size: int, - rootfs_name: str, payment: Optional[Payment] = None, - environment_variables: Optional[Mapping[str, str]] = None, + environment_variables: Optional[dict[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, memory: Optional[int] = None, @@ -520,57 +510,38 @@ async def create_instance( allow_amend: bool = False, internet: bool = True, aleph_api: bool = True, - volumes: Optional[List[Mapping]] = None, + hypervisor: 
Optional[HypervisorType] = None, + trusted_execution: Optional[TrustedExecutionEnvironment] = None, + volumes: Optional[list[Mapping]] = None, volume_persistence: str = "host", - ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + ssh_keys: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + requirements: Optional[HostRequirements] = None, ) -> Tuple[InstanceMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold) - - content = InstanceContent( + content = make_instance_content( + rootfs=rootfs, + rootfs_size=rootfs_size, + payment=payment, + environment_variables=environment_variables, address=address, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, allow_amend=allow_amend, - environment=FunctionEnvironment( - reproducible=False, - internet=internet, - aleph_api=aleph_api, - ), - variables=environment_variables, - resources=MachineResources( - vcpus=vcpus, - memory=memory, - seconds=timeout_seconds, - ), - rootfs=RootfsVolume( - parent=ParentVolume( - ref=rootfs, - use_latest=True, - ), - name=rootfs_name, - size_mib=rootfs_size, - persistence="host", - use_latest=True, - comment=( - "Official Aleph Debian root filesystem" - if rootfs == settings.DEFAULT_RUNTIME_ID - else "" - ), - ), - volumes=[parse_volume(volume) for volume in volumes], - time=time.time(), - authorized_keys=ssh_keys, + internet=internet, + aleph_api=aleph_api, + hypervisor=hypervisor, + trusted_execution=trusted_execution, + volumes=volumes, + ssh_keys=ssh_keys, metadata=metadata, - payment=payment, + requirements=requirements, ) + message, status, response = await self.submit( - 
content=content.dict(exclude_none=True), + content=content.model_dump(exclude_none=True), message_type=MessageType.instance, channel=channel, storage_engine=storage_engine, @@ -578,7 +549,7 @@ async def create_instance( raise_on_rejected=False, ) if status in (MessageStatus.PROCESSED, MessageStatus.PENDING): - return message, status + return message, status # type: ignore # get the reason for rejection rejected_message = await self.get_message_error(message.item_hash) @@ -592,17 +563,19 @@ async def create_instance( account_balance = float(error["account_balance"]) required_balance = float(error["required_balance"]) raise InsufficientFundsError( - required_funds=required_balance, available_funds=account_balance + token_type=TokenType.ALEPH, + required_funds=required_balance, + available_funds=account_balance, ) else: raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def forget( self, - hashes: List[str], + hashes: list[ItemHash], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, ) -> Tuple[ForgetMessage, MessageStatus]: @@ -616,74 +589,26 @@ async def forget( ) message, status, _ = await self.submit( - content=content.dict(exclude_none=True), + content=content.model_dump(exclude_none=True), message_type=MessageType.forget, channel=channel, storage_engine=storage_engine, allow_inlining=True, sync=sync, ) - return message, status - - @staticmethod - def compute_sha256(s: str) -> str: - h = hashlib.sha256() - h.update(s.encode("utf-8")) - return h.hexdigest() - - async def _prepare_aleph_message( - self, - message_type: MessageType, - content: Dict[str, Any], - channel: Optional[str], - allow_inlining: bool = True, - storage_engine: StorageEnum = StorageEnum.storage, - ) -> AlephMessage: - message_dict: Dict[str, Any] = { - "sender": self.account.get_address(), - "chain": 
self.account.CHAIN, - "type": message_type, - "content": content, - "time": time.time(), - "channel": channel, - } - - # Use the Pydantic encoder to serialize types like UUID, datetimes, etc. - item_content: str = json.dumps( - content, separators=(",", ":"), default=extended_json_encoder - ) - - if allow_inlining and (len(item_content) < settings.MAX_INLINE_SIZE): - message_dict["item_content"] = item_content - message_dict["item_hash"] = self.compute_sha256(item_content) - message_dict["item_type"] = ItemType.inline - else: - if storage_engine == StorageEnum.ipfs: - message_dict["item_hash"] = await self.ipfs_push( - content=content, - ) - message_dict["item_type"] = ItemType.ipfs - else: # storage - assert storage_engine == StorageEnum.storage - message_dict["item_hash"] = await self.storage_push( - content=content, - ) - message_dict["item_type"] = ItemType.storage - - message_dict = await self.account.sign_message(message_dict) - return parse_message(message_dict) + return message, status # type: ignore async def submit( self, content: Dict[str, Any], message_type: MessageType, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, storage_engine: StorageEnum = StorageEnum.storage, allow_inlining: bool = True, sync: bool = False, raise_on_rejected: bool = True, ) -> Tuple[AlephMessage, MessageStatus, Optional[Dict[str, Any]]]: - message = await self._prepare_aleph_message( + message = await self.generate_signed_message( message_type=message_type, content=content, channel=channel, @@ -699,20 +624,20 @@ async def _storage_push_file_with_message( self, file_content: bytes, store_content: StoreContent, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: """Push a file to the storage service.""" data = aiohttp.FormData() # Prepare the STORE message - message = await self._prepare_aleph_message( + message = await self.generate_signed_message( 
message_type=MessageType.store, - content=store_content.dict(exclude_none=True), + content=store_content.model_dump(exclude_none=True), channel=channel, ) metadata = { - "message": message.dict(exclude_none=True), + "message": message.model_dump(exclude_none=True), "sync": sync, } data.add_field( @@ -721,7 +646,7 @@ async def _storage_push_file_with_message( content_type="application/json", ) # Add the file - data.add_field("file", file_content) + data.add_field("file", BytesIO(file_content)) url = "/api/v0/storage/add_file" logger.debug(f"Posting file on {url}") @@ -731,7 +656,7 @@ async def _storage_push_file_with_message( message_status = ( MessageStatus.PENDING if resp.status == 202 else MessageStatus.PROCESSED ) - return message, message_status + return message, message_status # type: ignore async def _upload_file_native( self, @@ -740,8 +665,9 @@ async def _upload_file_native( guess_mime_type: bool = False, ref: Optional[str] = None, extra_fields: Optional[dict] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, + payment: Optional[Payment] = None, ) -> Tuple[StoreMessage, MessageStatus]: file_hash = hashlib.sha256(file_content).hexdigest() if magic and guess_mime_type: @@ -752,11 +678,12 @@ async def _upload_file_native( store_content = StoreContent( address=address, ref=ref, - item_type=StorageEnum.storage, - item_hash=file_hash, - mime_type=mime_type, + item_type=ItemType.storage, + item_hash=ItemHash(file_hash), + mime_type=mime_type, # type: ignore time=time.time(), - **extra_fields, + payment=payment, + **(extra_fields or {}), ) message, _ = await self._storage_push_file_with_message( file_content=file_content, diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index c79a07a5..6fa6d4ac 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -1,25 +1,75 @@ import json import logging +import os.path import ssl +import time from io import BytesIO 
-from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type, Union +from pathlib import Path +from typing import ( + Any, + AsyncIterable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, + overload, +) import aiohttp +from aiohttp import ClientResponseError +from aiohttp.web import HTTPNotFound from aleph_message import parse_message -from aleph_message.models import AlephMessage, ItemHash, ItemType +from aleph_message.models import ( + AlephMessage, + Chain, + ExecutableContent, + ItemHash, + ItemType, + MessageType, + ProgramContent, +) +from aleph_message.status import MessageStatus from pydantic import ValidationError +from aleph.sdk.client.services.crn import Crn +from aleph.sdk.client.services.dns import DNS +from aleph.sdk.client.services.instance import Instance +from aleph.sdk.client.services.port_forwarder import PortForwarder +from aleph.sdk.client.services.pricing import Pricing +from aleph.sdk.client.services.scheduler import Scheduler +from aleph.sdk.client.services.settings import Settings as NetworkSettingsService +from aleph.sdk.client.services.voucher import Vouchers + from ..conf import settings -from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError -from ..query.filters import MessageFilter, PostFilter -from ..query.responses import MessagesResponse, Post, PostsResponse -from ..types import GenericMessage +from ..exceptions import ( + FileTooLarge, + ForgottenMessageError, + InvalidHashError, + MessageNotFoundError, + RemovedMessageError, + ResourceNotFoundError, +) +from ..query.filters import BalanceFilter, MessageFilter, PostFilter +from ..query.responses import ( + BalanceResponse, + CreditsHistoryResponse, + MessagesResponse, + Post, + PostsResponse, + PriceResponse, +) +from ..types import GenericMessage, StoredContent from ..utils import ( Writable, check_unix_socket_valid, + compute_sha256, copy_async_readable_to_buffer, extended_json_encoder, get_message_type_value, + 
safe_getattr, ) from .abstract import AlephClient @@ -28,7 +78,7 @@ class AlephHttpClient(AlephClient): api_server: str - http_session: aiohttp.ClientSession + _http_session: Optional[aiohttp.ClientSession] def __init__( self, @@ -46,41 +96,67 @@ def __init__( if not self.api_server: raise ValueError("Missing API host") - connector: Union[aiohttp.BaseConnector, None] + self.connector: Union[aiohttp.BaseConnector, None] unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET + if ssl_context: - connector = aiohttp.TCPConnector(ssl=ssl_context) + self.connector = aiohttp.TCPConnector(ssl=ssl_context) elif unix_socket_path and allow_unix_sockets: check_unix_socket_valid(unix_socket_path) - connector = aiohttp.UnixConnector(path=unix_socket_path) + self.connector = aiohttp.UnixConnector(path=unix_socket_path) else: - connector = None - - # ClientSession timeout defaults to a private sentinel object and may not be None. - self.http_session = ( - aiohttp.ClientSession( - base_url=self.api_server, - connector=connector, - timeout=timeout, - json_serialize=extended_json_encoder, + self.connector = None + + self.timeout = timeout + self._http_session = None + + @property + def http_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + raise Exception( + f"{self.__class__.__name__} can only be using within an async context manager.\n\n" + "Please use it this way:\n\n" + f" async with {self.__class__.__name__}(...) 
as client:" ) - if timeout - else aiohttp.ClientSession( - base_url=self.api_server, - connector=connector, - json_serialize=lambda obj: json.dumps( - obj, default=extended_json_encoder - ), + + return self._http_session + + async def __aenter__(self): + if self._http_session is None: + self._http_session = ( + aiohttp.ClientSession( + base_url=self.api_server, + connector=self.connector, + timeout=self.timeout, + json_serialize=extended_json_encoder, + ) + if self.timeout + else aiohttp.ClientSession( + base_url=self.api_server, + connector=self.connector, + json_serialize=lambda obj: json.dumps( + obj, default=extended_json_encoder + ), + ) ) - ) - async def __aenter__(self) -> "AlephHttpClient": + # Initialize default services + self.dns = DNS(self) + self.port_forwarder = PortForwarder(self) + self.crn = Crn(self) + self.scheduler = Scheduler(self) + self.instance = Instance(self) + self.pricing = Pricing(self) + self.voucher = Vouchers(self) + self.network_settings = NetworkSettingsService(self) return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.http_session.close() + # Avoid cascade in error handling + if self._http_session is not None: + await self._http_session.close() - async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: + async def _fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: params: Dict[str, Any] = {"keys": key} async with self.http_session.get( @@ -89,9 +165,10 @@ async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: resp.raise_for_status() result = await resp.json() data = result.get("data", dict()) - return data.get(key) + final_result = data.get(key) + return final_result - async def fetch_aggregates( + async def _fetch_aggregates( self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: keys_str = ",".join(keys) if keys else "" @@ -108,6 +185,32 @@ async def fetch_aggregates( data = result.get("data", dict()) return data + async def 
get_aggregate(self, address: str, key: str) -> Optional[Dict[str, Dict]]: + try: + return await self.fetch_aggregate(address=address, key=key) + except ClientResponseError as e: + if e.status == 404: + return None + raise + + async def get_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Optional[Dict[str, Dict]]: + try: + return await self.fetch_aggregates(address=address, keys=keys) + except ClientResponseError as e: + if e.status == 404: + return None + raise + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: + return await self._fetch_aggregate(address=address, key=key) + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + return await self._fetch_aggregates(address=address, keys=keys) + async def get_posts( self, page_size: int = 200, @@ -143,7 +246,7 @@ async def get_posts( posts: List[Post] = [] for post_raw in posts_raw: try: - posts.append(Post.parse_obj(post_raw)) + posts.append(Post.model_validate(post_raw)) except ValidationError as e: if not ignore_invalid_messages: raise e @@ -183,6 +286,9 @@ async def download_file_to_buffer( ) else: raise FileTooLarge(f"The file from {file_hash} is too large") + if response.status == 404: + raise ResourceNotFoundError() + return None async def download_file_ipfs_to_buffer( self, @@ -206,14 +312,11 @@ async def download_file_ipfs_to_buffer( else: response.raise_for_status() - async def download_file( - self, - file_hash: str, - ) -> bytes: + async def download_file(self, file_hash: str) -> bytes: """ Get a file from the storage engine as raw bytes. - Warning: Downloading large files can be slow and memory intensive. + Warning: Downloading large files can be slow and memory intensive. Use `download_file_to()` to download them directly to disk instead. :param file_hash: The hash of the file to retrieve. 
""" @@ -221,6 +324,30 @@ async def download_file( await self.download_file_to_buffer(file_hash, output_buffer=buffer) return buffer.getvalue() + async def download_file_to_path( + self, + file_hash: str, + path: Union[Path, str], + ) -> Path: + """ + Download a file from the storage engine to given path. + + :param file_hash: The hash of the file to retrieve. + :param path: The path to which the file should be saved. + """ + if not isinstance(path, Path): + path = Path(path) + + if not os.path.exists(path): + dir_path = os.path.dirname(path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + + with open(path, "wb") as file_buffer: + await self.download_file_to_buffer(file_hash, output_buffer=file_buffer) + + return path + async def download_file_ipfs( self, file_hash: str, @@ -298,23 +425,43 @@ async def get_messages( pagination_item=response_json["pagination_item"], ) + @overload + async def get_message( # type: ignore + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + ) -> GenericMessage: ... + + @overload async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, - ) -> GenericMessage: + with_status: bool = False, + ) -> Tuple[GenericMessage, MessageStatus]: ... 
+ + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + with_status: bool = False, + ) -> Union[GenericMessage, Tuple[GenericMessage, MessageStatus]]: async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp: try: resp.raise_for_status() except aiohttp.ClientResponseError as e: if e.status == 404: - raise MessageNotFoundError(f"No such hash {item_hash}") + raise MessageNotFoundError(f"No such hash {item_hash}") from e raise e message_raw = await resp.json() if message_raw["status"] == "forgotten": raise ForgottenMessageError( f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}" ) + if message_raw["status"] == "removed": + raise RemovedMessageError( + f"The requested message {message_raw['item_hash']} has been removed by {', '.join(message_raw['reason'])}" + ) message = parse_message(message_raw["message"]) if message_type: expected_type = get_message_type_value(message_type) @@ -323,7 +470,10 @@ async def get_message( f"The message type '{message.type}' " f"does not match the expected type '{expected_type}'" ) - return message + if with_status: + return message, message_raw["status"] # type: ignore + else: + return message # type: ignore async def get_message_error( self, @@ -341,6 +491,10 @@ async def get_message_error( raise ForgottenMessageError( f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}" ) + if message_raw["status"] == "removed": + raise RemovedMessageError( + f"The requested message {message_raw['item_hash']} has been removed by {', '.join(message_raw['reason'])}" + ) if message_raw["status"] != "rejected": return None return { @@ -369,3 +523,207 @@ async def watch_messages( yield parse_message(data) elif msg.type == aiohttp.WSMsgType.ERROR: break + + async def get_store_estimated_price( + self, + storage_size_mib: int, + ) -> PriceResponse: + """ + Get the 
estimated price for a store operation. + + :param storage_size_mib: size in mib you want to store + :return: Price response with cost information + """ + content = { + "address": "0xWeDoNotNeedARealAddress", + "time": time.time(), + "item_type": ItemType.storage, + "estimated_size_mib": storage_size_mib, + "item_hash": compute_sha256("dummy_value"), + } + + item_content: str = json.dumps( + content, + separators=(",", ":"), + default=extended_json_encoder, + ) + + message_dict = dict( + sender=content["address"], + chain=Chain.ETH, + type=MessageType.store, + content=content, + item_content=item_content, + time=time.time(), + channel=settings.DEFAULT_CHANNEL, + item_type=ItemType.inline, + item_hash=compute_sha256(item_content), + signature="0x" + "0" * 130, # Add a dummy signature to pass validation + ) + + message = parse_message(message_dict) + + async with self.http_session.post( + "/api/v0/price/estimate", json=dict(message=message) + ) as resp: + try: + resp.raise_for_status() + response_json = await resp.json() + cost = response_json.get("cost", None) + + return PriceResponse( + cost=cost, + required_tokens=response_json["required_tokens"], + payment_type=response_json["payment_type"], + ) + except aiohttp.ClientResponseError as e: + raise e + + async def get_estimated_price( + self, + content: ExecutableContent, + ) -> PriceResponse: + cleaned_content = content.model_dump(exclude_none=True) + item_content: str = json.dumps( + cleaned_content, + separators=(",", ":"), + default=extended_json_encoder, + ) + message_dict = dict( + sender=content.address, + chain=Chain.ETH, + type=( + MessageType.program + if isinstance(content, ProgramContent) + else MessageType.instance + ), + content=cleaned_content, + item_content=item_content, + time=time.time(), + channel=settings.DEFAULT_CHANNEL, + item_type=ItemType.inline, + item_hash=compute_sha256(item_content), + signature="0x" + "0" * 130, # Add a dummy signature to pass validation + ) + + message = 
parse_message(message_dict) + + async with self.http_session.post( + "/api/v0/price/estimate", json=dict(message=message) + ) as resp: + try: + resp.raise_for_status() + response_json = await resp.json() + cost = response_json.get("cost", None) + + return PriceResponse( + cost=cost, + required_tokens=response_json["required_tokens"], + payment_type=response_json["payment_type"], + ) + except aiohttp.ClientResponseError as e: + raise e + + async def get_program_price(self, item_hash: str) -> PriceResponse: + async with self.http_session.get(f"/api/v0/price/{item_hash}") as resp: + try: + resp.raise_for_status() + response_json = await resp.json() + cost = response_json.get("cost", None) + required_tokens = response_json["required_tokens"] + + return PriceResponse( + required_tokens=required_tokens, + cost=cost, + payment_type=response_json["payment_type"], + ) + except aiohttp.ClientResponseError as e: + if e.status == 400: + raise InvalidHashError(f"Bad request or no such hash {item_hash}") + raise e + + async def get_message_status(self, item_hash: str) -> MessageStatus: + """return Status of a message""" + async with self.http_session.get( + f"/api/v0/messages/{item_hash}/status" + ) as resp: + if resp.status == HTTPNotFound.status_code: + raise MessageNotFoundError(f"No such hash {item_hash}") + resp.raise_for_status() + result = await resp.json() + return MessageStatus(result["status"]) + + async def get_stored_content( + self, + item_hash: str, + ) -> StoredContent: + """return the underlying content for a store message""" + + result, resp = None, None + try: + message: AlephMessage + message, status = await self.get_message( + item_hash=ItemHash(item_hash), with_status=True + ) + if status not in [MessageStatus.PROCESSED, MessageStatus.REMOVING]: + resp = f"Invalid message status: {status}" + elif message.type != MessageType.store: + resp = f"Invalid message type: {message.type}" + elif not message.content.item_hash: + resp = f"Invalid CID: 
{message.content.item_hash}" + else: + filename = safe_getattr(message.content, "metadata.name") + item_hash = message.content.item_hash + url = ( + f"{self.api_server}/api/v0/storage/raw/" + if len(item_hash) == 64 + else settings.IPFS_GATEWAY + ) + item_hash + result = StoredContent( + filename=filename, hash=item_hash, url=url, error=None + ) + except MessageNotFoundError: + resp = f"Message not found: {item_hash}" + except ForgottenMessageError: + resp = f"Message forgotten: {item_hash}" + except RemovedMessageError as e: + resp = f"Message resources not available {item_hash}: {str(e)}" + return ( + result + if result + else StoredContent(error=resp, filename=None, hash=None, url=None) + ) + + async def get_credit_history( + self, + address: str, + page_size: int = 200, + page: int = 1, + ) -> CreditsHistoryResponse: + """Return List of credits history for a specific addresses""" + + params = { + "page": str(page), + "pagination": str(page_size), + } + + async with self.http_session.get( + f"/api/v0/addresses/{address}/credit_history", params=params + ) as resp: + resp.raise_for_status() + result = await resp.json() + return CreditsHistoryResponse.model_validate(result) + + async def get_balances( + self, + address: str, + filter: Optional[BalanceFilter] = None, + ) -> BalanceResponse: + + async with self.http_session.get( + f"/api/v0/addresses/{address}/balance", + params=filter.as_http_params() if filter else None, + ) as resp: + resp.raise_for_status() + result = await resp.json() + return BalanceResponse.model_validate(result) diff --git a/src/aleph/sdk/client/services/__init__.py b/src/aleph/sdk/client/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/client/services/authenticated_port_forwarder.py b/src/aleph/sdk/client/services/authenticated_port_forwarder.py new file mode 100644 index 00000000..765ac2f1 --- /dev/null +++ b/src/aleph/sdk/client/services/authenticated_port_forwarder.py @@ -0,0 +1,190 @@ +from 
typing import TYPE_CHECKING, Optional, Tuple + +from aleph_message.models import AggregateMessage, ItemHash +from aleph_message.status import MessageStatus + +from aleph.sdk.client.services.base import AggregateConfig +from aleph.sdk.client.services.port_forwarder import PortForwarder +from aleph.sdk.exceptions import MessageNotProcessed, NotAuthorize +from aleph.sdk.types import AllForwarders, Ports +from aleph.sdk.utils import safe_getattr + +if TYPE_CHECKING: + from aleph.sdk.client.abstract import AuthenticatedAlephClient + + +class AuthenticatedPortForwarder(PortForwarder): + """ + Authenticated Port Forwarder services with create and update capabilities + """ + + def __init__(self, client: "AuthenticatedAlephClient"): + super().__init__(client) + + async def _verify_status_processed_and_ownership( + self, item_hash: ItemHash + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Verify that the message is well processed (and not rejected / pending), + This also verify the ownership of the message + """ + message: AggregateMessage + status: MessageStatus + message, status = await self._client.get_message( + item_hash=item_hash, + with_status=True, + ) + + # We ensure message is not Rejected (Might not be processed yet) + if status not in [MessageStatus.PROCESSED, MessageStatus.PENDING]: + raise MessageNotProcessed(item_hash=item_hash, status=status) + + message_content = safe_getattr(message, "content") + address = safe_getattr(message_content, "address") + + if ( + not hasattr(self._client, "account") + or address != self._client.account.get_address() + ): + current_address = ( + self._client.account.get_address() + if hasattr(self._client, "account") + else "unknown" + ) + raise NotAuthorize( + item_hash=item_hash, + target_address=address, + current_address=current_address, + ) + return message, status + + async def get_address_ports( + self, address: Optional[str] = None + ) -> AggregateConfig[AllForwarders]: + """ + Get all port forwarding configurations 
for an address + + Args: + address: The address to fetch configurations for. + If None, uses the authenticated client's account address. + + Returns: + Port forwarding configurations + """ + if address is None: + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("No account provided and client is not authenticated") + address = self._client.account.get_address() + + return await super().get_address_ports(address=address) + + async def get_ports( + self, item_hash: ItemHash = None, address: Optional[str] = None + ) -> Optional[Ports]: + """ + Get port forwarding configuration for a specific item hash + + Args: + address: The address to fetch configurations for. + If None, uses the authenticated client's account address. + item_hash: The hash of the item to get configuration for + + Returns: + Port configuration if found, otherwise empty Ports object + """ + if address is None: + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("No account provided and client is not authenticated") + address = self._client.account.get_address() + + if item_hash is None: + raise ValueError("item_hash must be provided") + + return await super().get_ports(address=address, item_hash=item_hash) + + async def create_ports( + self, item_hash: ItemHash, ports: Ports + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create a new port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) + ports: Dictionary mapping port numbers to PortFlags + + Returns: + Dictionary with the result of the operation + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + content = {str(item_hash): ports.model_dump()} + + # Check if create_aggregate exists on the client + return await 
self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + + async def update_ports( + self, item_hash: ItemHash, ports: Ports + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Update an existing port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) + ports: Dictionary mapping port numbers to PortFlags + + Returns: + Dictionary with the result of the operation + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + content = {} + + content[str(item_hash)] = ports.model_dump() + + message, status = await self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + + return message, status + + async def delete_ports( + self, item_hash: ItemHash + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Delete port forwarding configuration for an item hash + + Args: + item_hash: The hash of the item (instance/program/IPFS website) to delete configuration for + + Returns: + Dictionary with the result of the operation + """ + if not hasattr(self._client, "account") or not self._client.account: + raise ValueError("An account is required for this operation") + + # Pre Check + # _, _ = await self._verify_status_processed_and_ownership(item_hash=item_hash) + + # Get the Port Config of the item_hash + port: Optional[Ports] = await self.get_ports(item_hash=item_hash) + if not port: + raise + + content = {} + content[str(item_hash)] = port.model_dump() + + # Create a new aggregate with the updated content + message, status = await self._client.create_aggregate( # type: ignore + key=self.aggregate_key, content=content + ) + return message, status diff --git a/src/aleph/sdk/client/services/authenticated_voucher.py 
b/src/aleph/sdk/client/services/authenticated_voucher.py new file mode 100644 index 00000000..48d7d73d --- /dev/null +++ b/src/aleph/sdk/client/services/authenticated_voucher.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING, Optional, overload + +from typing_extensions import override + +from aleph.sdk.types import Voucher + +from .voucher import Vouchers + +if TYPE_CHECKING: + from aleph.sdk.client.abstract import AuthenticatedAlephClient + + +class AuthenticatedVoucher(Vouchers): + """ + This service is same logic than Vouchers but allow to don't pass address + to use account address + """ + + def __init__(self, client: "AuthenticatedAlephClient"): + super().__init__(client) + + @overload + def _resolve_address(self, address: str) -> str: ... + + @overload + def _resolve_address(self, address: None) -> str: ... + + @override + def _resolve_address(self, address: Optional[str] = None) -> str: + """ + Resolve the address to use. Prefer the provided address, fallback to account. + """ + if address: + return address + if self._client.account: + return self._client.account.get_address() + + raise ValueError("No address provided and no account configured") + + @override + async def get_vouchers(self, address: Optional[str] = None) -> list[Voucher]: + """ + Retrieve all vouchers for the account / specific address, across EVM and Solana chains. + """ + address = address or self._client.account.get_address() + return await super().get_vouchers(address=address) + + @override + async def get_evm_vouchers(self, address: Optional[str] = None) -> list[Voucher]: + """ + Retrieve vouchers specific to EVM chains for a specific address. + """ + address = address or self._client.account.get_address() + return await super().get_evm_vouchers(address=address) + + @override + async def get_solana_vouchers(self, address: Optional[str] = None) -> list[Voucher]: + """ + Fetch Solana vouchers for a specific address. 
+ """ + address = address or self._client.account.get_address() + return await super().get_solana_vouchers(address=address) diff --git a/src/aleph/sdk/client/services/base.py b/src/aleph/sdk/client/services/base.py new file mode 100644 index 00000000..77a9cc0b --- /dev/null +++ b/src/aleph/sdk/client/services/base.py @@ -0,0 +1,42 @@ +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar + +from pydantic import BaseModel + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +T = TypeVar("T", bound=BaseModel) + + +class AggregateConfig(BaseModel, Generic[T]): + """ + A generic container for "aggregate" data of type T. + - `data` will be either None or a list of T-instances. + """ + + data: Optional[List[T]] = None + + +class BaseService(ABC, Generic[T]): + aggregate_key: str + model_cls: Type[T] + + def __init__(self, client: "AlephHttpClient"): + self._client = client + self.model_cls: Type[T] + + async def get_config(self, address: str): + + aggregate_data = await self._client.get_aggregate( + address=address, key=self.aggregate_key + ) + + if aggregate_data: + model_instance = self.model_cls.model_validate(aggregate_data) + config = AggregateConfig[T](data=[model_instance]) + else: + config = AggregateConfig[T](data=None) + + return config diff --git a/src/aleph/sdk/client/services/crn.py b/src/aleph/sdk/client/services/crn.py new file mode 100644 index 00000000..19477cb4 --- /dev/null +++ b/src/aleph/sdk/client/services/crn.py @@ -0,0 +1,396 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import aiohttp +from aiohttp.client_exceptions import ClientResponseError +from aleph_message.models import ItemHash +from pydantic import BaseModel, NonNegativeInt, PositiveInt + +from aleph.sdk.conf import settings +from aleph.sdk.exceptions import MethodNotAvailableOnCRN, VmNotFoundOnHost +from aleph.sdk.types import ( + CrnExecutionV1, + CrnExecutionV2, + 
CrnV1List, + CrnV2List, + DictLikeModel, + VmResources, +) +from aleph.sdk.utils import extract_valid_eth_address, sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class CpuLoad(BaseModel): + load1: float + load5: float + load15: float + + +class CoreFrequencies(BaseModel): + min: float + max: float + + +class CpuInfo(BaseModel): + count: PositiveInt + load_average: CpuLoad + core_frequencies: CoreFrequencies + + +class CpuProperties(BaseModel): + architecture: str + vendor: str + features: List[str] = [] + + +class MemoryInfo(BaseModel): + total_kB: PositiveInt + available_kB: NonNegativeInt + + +class DiskInfo(BaseModel): + total_kB: PositiveInt + available_kB: NonNegativeInt + + +class UsagePeriod(BaseModel): + start_timestamp: datetime + duration_seconds: NonNegativeInt + + +class Properties(BaseModel): + cpu: CpuProperties + + +class GPU(BaseModel): + vendor: str + model: str + device_name: str + device_class: str + pci_host: str + device_id: str + compatible: bool + + +class GpuUsages(BaseModel): + devices: List[GPU] = [] + available_devices: List[GPU] = [] + + +class SystemUsage(BaseModel): + cpu: CpuInfo + mem: MemoryInfo + disk: DiskInfo + period: UsagePeriod + properties: Properties + gpu: GpuUsages + active: bool + + +class NetworkGPUS(BaseModel): + total_gpu_count: int + available_gpu_count: int + available_gpu_list: dict[str, List[GPU]] # str = node_url + used_gpu_list: dict[str, List[GPU]] # str = node_url + + +class CRN(DictLikeModel): + # This Model work as dict but where we can type what we need / apply logic on top + + # Simplify search + hash: str + name: str + address: str + + gpu_support: Optional[bool] = False + confidential_support: Optional[bool] = False + qemu_support: Optional[bool] = False + system_usage: Optional[SystemUsage] = None + + version: Optional[str] = "0.0.0" + payment_receiver_address: Optional[str] # Can be None if not configured + + +class CrnList(DictLikeModel): + crns: list[CRN] = 
[] + + @classmethod + def from_api(cls, payload: dict) -> "CrnList": + raw_list = payload.get("crns", []) + crn_list = [ + CRN.model_validate(item) if not isinstance(item, CRN) else item + for item in raw_list + ] + return cls(crns=crn_list) + + def find_gpu_on_network(self): + gpu_count: int = 0 + available_gpu_count: int = 0 + + compatible_gpu: Dict[str, List[GPU]] = {} + available_compatible_gpu: Dict[str, List[GPU]] = {} + + for crn in self.crns: + if not crn.gpu_support: + continue + + # Extracts used GPU + compatible_gpu[crn.address] = [] + for gpu in crn.get("compatible_gpus", []): + compatible_gpu[crn.address].append(GPU.model_validate(gpu)) + gpu_count += 1 + + # Extracts available GPU + available_compatible_gpu[crn.address] = [] + for gpu in crn.get("compatible_available_gpus", []): + available_compatible_gpu[crn.address].append(GPU.model_validate(gpu)) + gpu_count += 1 + available_gpu_count += 1 + + return NetworkGPUS( + total_gpu_count=gpu_count, + available_gpu_count=available_gpu_count, + used_gpu_list=compatible_gpu, + available_gpu_list=available_compatible_gpu, + ) + + def filter_crn( + self, + crn_version: Optional[str] = None, + ipv6: bool = False, + stream_address: bool = False, + confidential: bool = False, + gpu: bool = False, + vm_resources: Optional[VmResources] = None, + ) -> list[CRN]: + """Filter compute resource node list, unfiltered by default. + Args: + crn_version (str): Filter by specific crn version. + ipv6 (bool): Filter invalid IPv6 configuration. + stream_address (bool): Filter invalid payment receiver address. + confidential (bool): Filter by confidential computing support. + gpu (bool): Filter by GPU support. + vm_resources (VmResources): Filter by VM need, vcpus, memory, disk. + Returns: + list[CRN]: List of compute resource nodes. 
(if no filter applied, return all) + """ + + filtered_crn: list[CRN] = [] + for crn in self.crns: + # Check crn version + if crn_version and (crn.version or "0.0.0") < crn_version: + continue + + # Filter with ipv6 check + if ipv6: + ipv6_check = crn.get("ipv6_check") + + if not ipv6_check or not all(ipv6_check.values()): + continue + + if stream_address and not extract_valid_eth_address( + crn.payment_receiver_address or "" + ): + continue + + # Confidential Filter + if confidential and not crn.confidential_support: + continue + + # Filter with GPU / Available GPU + available_gpu = crn.get("compatible_available_gpus") + if gpu and (not crn.gpu_support or not available_gpu): + continue + + # Filter VM resources + if vm_resources: + crn_usage = crn.system_usage + if not crn_usage: + continue + + # Check CPU count + if crn_usage.cpu.count < vm_resources.vcpus: + continue + + # Convert MiB to kB (1 MiB = 1024 kB) for proper comparison + memory_kb_required = vm_resources.memory * 1024 + disk_kb_required = vm_resources.disk_mib * 1024 + + # Check free memory + if crn_usage.mem.available_kB < memory_kb_required: + continue + + # Check free disk + if crn_usage.disk.available_kB < disk_kb_required: + continue + + filtered_crn.append(crn) + return filtered_crn + + # Find CRN by address + def find_crn_by_address(self, address: str) -> Optional[CRN]: + for crn in self.crns: + if crn.address == sanitize_url(address): + return crn + return None + + # Find CRN by hash + def find_crn_by_hash(self, crn_hash: str) -> Optional[CRN]: + for crn in self.crns: + if crn.hash == crn_hash: + return crn + return None + + def find_crn( + self, + address: Optional[str] = None, + crn_hash: Optional[str] = None, + ) -> Optional[CRN]: + """Find CRN by address or hash (both optional, address priority) + + Args: + address (Optional[str], optional): url of the node. Defaults to None. + crn_hash (Optional[str], optional): hash of the nodes. Defaults to None. 
+ + Returns: + Optional[CRN]: CRN object or None if not found + """ + if address: + return self.find_crn_by_address(address) + if crn_hash: + return self.find_crn_by_hash(crn_hash) + return None + + +class Crn: + """ + This services allow interact with CRNS API + TODO: ADD + /about/executions/details + /about/executions/records + /about/usage/system + /about/certificates + /about/capability + /about/config + /status/check/fastapi + /status/check/fastapi/legacy + /status/check/host + /status/check/version + /status/check/ipv6 + /status/config + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_last_crn_version(self): + """ + Fetch Last version tag from aleph-vm github repo + """ + # Create a new session for external domain requests + async with aiohttp.ClientSession() as session: + async with session.get(settings.CRN_VERSION_URL) as resp: + resp.raise_for_status() + data = await resp.json() + return data.get("tag_name") + + async def get_crns_list(self, only_active: bool = True) -> CrnList: + """ + Query a persistent VM running on aleph.im to retrieve list of CRNs: + https://crns-list.aleph.sh/crns.json + + Parameters + ---------- + only_active : bool + If True (the default), only return active CRNs (i.e. `filter_inactive=false`). + If False, return all CRNs (i.e. `filter_inactive=true`). + + Returns + ------- + dict + The parsed JSON response from /crns.json. 
+ """ + # Convert bool to string for the query parameter + filter_inactive_str = str(only_active).lower() + params = {"filter_inactive": filter_inactive_str} + + # Create a new session for external domain requests + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.CRN_LIST_URL), params=params + ) as resp: + resp.raise_for_status() + return CrnList.from_api(await resp.json()) + + async def get_active_vms_v2(self, crn_address: str) -> CrnV2List: + endpoint = "/v2/about/executions/list" + + full_url = sanitize_url(crn_address + endpoint) + + async with aiohttp.ClientSession() as session: + async with session.get(full_url) as resp: + resp.raise_for_status() + raw = await resp.json() + vm_mmap = CrnV2List.model_validate(raw) + return vm_mmap + + async def get_active_vms_v1(self, crn_address: str) -> CrnV1List: + endpoint = "/about/executions/list" + + full_url = sanitize_url(crn_address + endpoint) + + async with aiohttp.ClientSession() as session: + async with session.get(full_url) as resp: + resp.raise_for_status() + raw = await resp.json() + vm_map = CrnV1List.model_validate(raw) + return vm_map + + async def get_active_vms(self, crn_address: str) -> Union[CrnV2List, CrnV1List]: + try: + return await self.get_active_vms_v2(crn_address) + except ClientResponseError as e: + if e.status == 404: + return await self.get_active_vms_v1(crn_address) + raise + + async def get_vm( + self, crn_address: str, item_hash: ItemHash + ) -> Optional[Union[CrnExecutionV1, CrnExecutionV2]]: + vms = await self.get_active_vms(crn_address) + + vm_map: Dict[ItemHash, Union[CrnExecutionV1, CrnExecutionV2]] = vms.root + + if item_hash not in vm_map: + return None + + return vm_map[item_hash] + + async def update_instance_config(self, crn_address: str, item_hash: ItemHash): + vm = await self.get_vm(crn_address, item_hash) + + if not vm: + raise VmNotFoundOnHost(crn_url=crn_address, item_hash=item_hash) + + # CRN have two week to upgrade their 
node, + # So if the CRN does not have the update + # We can't update config + if isinstance(vm, CrnExecutionV1): + raise MethodNotAvailableOnCRN() + + full_url = sanitize_url(crn_address + f"/control/{item_hash}/update") + + async with aiohttp.ClientSession() as session: + async with session.post(full_url) as resp: + resp.raise_for_status() + return await resp.json() + + # Gpu Functions Helper + async def fetch_gpu_on_network( + self, + only_active: bool = True, + ) -> NetworkGPUS: + crn_list = await self.get_crns_list(only_active) + return crn_list.find_gpu_on_network() diff --git a/src/aleph/sdk/client/services/dns.py b/src/aleph/sdk/client/services/dns.py new file mode 100644 index 00000000..95132390 --- /dev/null +++ b/src/aleph/sdk/client/services/dns.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, List, Optional + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.conf import settings +from aleph.sdk.types import Dns, DnsListAdapter +from aleph.sdk.utils import sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class DNS: + """ + This Service mostly made to get active dns for instance: + `https://api.dns.public.aleph.sh/instances/list` + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_public_dns(self) -> List[Dns]: + """ + Get all the public dns ha + """ + async with aiohttp.ClientSession() as session: + async with session.get(sanitize_url(settings.DNS_API)) as resp: + resp.raise_for_status() + raw = await resp.json() + + return DnsListAdapter.validate_json(raw) + + async def get_public_dns_by_host(self, crn_hostname): + """ + Get all the public dns with filter on crn_url + """ + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.DNS_API), params={"crn_url": crn_hostname} + ) as resp: + resp.raise_for_status() + raw = await resp.json() + + return DnsListAdapter.validate_json(raw) + + async def 
get_dns_for_instance(self, vm_hash: ItemHash) -> Optional[List[Dns]]: + async with aiohttp.ClientSession() as session: + async with session.get( + sanitize_url(settings.DNS_API), params={"item_hash": vm_hash} + ) as resp: + resp.raise_for_status() + raw = await resp.json() + return DnsListAdapter.validate_json(raw) diff --git a/src/aleph/sdk/client/services/instance.py b/src/aleph/sdk/client/services/instance.py new file mode 100644 index 00000000..9a2dcf20 --- /dev/null +++ b/src/aleph/sdk/client/services/instance.py @@ -0,0 +1,145 @@ +import asyncio +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +from aleph_message.models import InstanceMessage, ItemHash, MessageType, PaymentType +from aleph_message.status import MessageStatus + +from aleph.sdk.client.services.crn import CrnList +from aleph.sdk.query.filters import MessageFilter +from aleph.sdk.query.responses import MessagesResponse + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + +from aleph.sdk.types import ( + CrnExecutionV1, + CrnExecutionV2, + InstanceAllocationsInfo, + InstanceManual, + InstancesExecutionList, + InstanceWithScheduler, +) +from aleph.sdk.utils import safe_getattr, sanitize_url + + +class Instance: + """ + This is utils functions that used multiple Service + exemple getting info about Allocations / exeuction of any instances (hold or not) + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_name_of_executable(self, item_hash: ItemHash) -> Optional[str]: + try: + message: Any = await self._client.get_message(item_hash=item_hash) + if hasattr(message, "content") and hasattr(message.content, "metadata"): + return message.content.metadata.get("name") + elif isinstance(message, dict): + # Handle dictionary response format + if "content" in message and isinstance(message["content"], dict): + if "metadata" in message["content"] and isinstance( + message["content"]["metadata"], dict + ): + return 
message["content"]["metadata"].get("name") + return None + except Exception: + return None + + async def get_instance_allocation_info( + self, msg: InstanceMessage, crn_list: CrnList + ) -> Tuple[InstanceMessage, Union[InstanceManual, InstanceWithScheduler]]: + vm_hash = msg.item_hash + payment_type = safe_getattr(msg, "content.payment.type.value") + firmware = safe_getattr(msg, "content.environment.trusted_execution.firmware") + has_gpu = safe_getattr(msg, "content.requirements.gpu") + + is_hold = payment_type == PaymentType.hold.value + is_conf = bool(firmware and len(firmware) == 64) + + if is_hold and not is_conf and not has_gpu: + alloc = await self._client.scheduler.get_allocation(vm_hash) + info = InstanceWithScheduler(source="scheduler", allocations=alloc) + else: + crn_hash = safe_getattr(msg, "content.requirements.node.node_hash") + node = crn_list.find_crn_by_hash(crn_hash) + url = sanitize_url(node.address) if node else "" + + info = InstanceManual(source="manual", crn_url=url) + return msg, info + + async def get_instances(self, address: str) -> List[InstanceMessage]: + resp: MessagesResponse = await self._client.get_messages( + message_filter=MessageFilter( + message_types=[MessageType.instance], + addresses=[address], + message_statuses=[MessageStatus.PROCESSED, MessageStatus.REMOVING], + ), + page_size=100, + ) + return resp.messages + + async def get_instances_allocations(self, messages_list, only_processed=True): + crn_list = await self._client.crn.get_crns_list(only_active=False) + + tasks = [] + for msg in messages_list: + if only_processed: + status = await self._client.get_message_status(msg.item_hash) + if ( + status != MessageStatus.PROCESSED + and status != MessageStatus.REMOVING + ): + continue + tasks.append(self.get_instance_allocation_info(msg, crn_list)) + + results = await asyncio.gather(*tasks) + + mapping = {ItemHash(msg.item_hash): info for msg, info in results} + + return InstanceAllocationsInfo.model_validate(mapping) + + async 
def get_instance_executions_info( + self, instances: InstanceAllocationsInfo + ) -> InstancesExecutionList: + async def _fetch( + item_hash: ItemHash, + alloc: Union[InstanceManual, InstanceWithScheduler], + ) -> tuple[str, Optional[Union[CrnExecutionV1, CrnExecutionV2]]]: + """Retrieve the execution record for an item hash.""" + if isinstance(alloc, InstanceManual): + crn_url = sanitize_url(alloc.crn_url) + else: + if not alloc.allocations: + return str(item_hash), None + crn_url = sanitize_url(alloc.allocations.node.url) + + if not crn_url: + return str(item_hash), None + + try: + execution = await self._client.crn.get_vm( + item_hash=item_hash, + crn_address=crn_url, + ) + return str(item_hash), execution + except Exception: + return str(item_hash), None + + fetch_tasks = [] + msg_hash_map = {} + + for item_hash, alloc in instances.root.items(): + fetch_tasks.append(_fetch(item_hash, alloc)) + msg_hash_map[str(item_hash)] = item_hash + + results = await asyncio.gather(*fetch_tasks) + + mapping = { + ItemHash(msg_hash): exec_info + for msg_hash, exec_info in results + if msg_hash is not None and exec_info is not None + } + + return InstancesExecutionList.model_validate(mapping) diff --git a/src/aleph/sdk/client/services/port_forwarder.py b/src/aleph/sdk/client/services/port_forwarder.py new file mode 100644 index 00000000..923d0931 --- /dev/null +++ b/src/aleph/sdk/client/services/port_forwarder.py @@ -0,0 +1,44 @@ +from typing import TYPE_CHECKING, Optional + +from aleph_message.models import ItemHash + +from aleph.sdk.client.services.base import AggregateConfig, BaseService +from aleph.sdk.types import AllForwarders, Ports + +if TYPE_CHECKING: + pass + + +class PortForwarder(BaseService[AllForwarders]): + """ + Ports Forwarder Logic + """ + + aggregate_key = "port-forwarding" + model_cls = AllForwarders + + def __init__(self, client): + super().__init__(client=client) + + async def get_address_ports(self, address: str) -> AggregateConfig[AllForwarders]: + 
result = await self.get_config(address=address) + return result + + async def get_ports(self, item_hash: ItemHash, address: str) -> Optional[Ports]: + """ + Get Ports Forwarder of Instance / Program / IPFS website from aggregate + """ + ports_config: AggregateConfig[AllForwarders] = await self.get_address_ports( + address=address + ) + + if ports_config.data is None: + return Ports(ports={}) + + for forwarder in ports_config.data: + ports_map = forwarder.root + + if str(item_hash) in ports_map: + return ports_map[str(item_hash)] + + return Ports(ports={}) diff --git a/src/aleph/sdk/client/services/pricing.py b/src/aleph/sdk/client/services/pricing.py new file mode 100644 index 00000000..e2b51c50 --- /dev/null +++ b/src/aleph/sdk/client/services/pricing.py @@ -0,0 +1,235 @@ +import logging +import math +from enum import Enum +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from aleph.sdk.client.services.base import BaseService +from aleph.sdk.conf import settings + +if TYPE_CHECKING: + pass + +from decimal import Decimal + +from pydantic import BaseModel, RootModel + +logger = logging.getLogger(__name__) + + +class PricingEntity(str, Enum): + STORAGE = "storage" + WEB3_HOSTING = "web3_hosting" + PROGRAM = "program" + PROGRAM_PERSISTENT = "program_persistent" + INSTANCE = "instance" + INSTANCE_CONFIDENTIAL = "instance_confidential" + INSTANCE_GPU_STANDARD = "instance_gpu_standard" + INSTANCE_GPU_PREMIUM = "instance_gpu_premium" + + +class GroupEntity(str, Enum): + STORAGE = "storage" + WEBSITE = "website" + PROGRAM = "program" + INSTANCE = "instance" + CONFIDENTIAL = "confidential" + GPU = "gpu" + ALL = "all" + + +class Price(BaseModel): + payg: Optional[Decimal] = None + holding: Optional[Decimal] = None + fixed: Optional[Decimal] = None + credit: Optional[Decimal] = None + + +class ComputeUnit(BaseModel): + vcpus: int + memory_mib: int + disk_mib: int + + +class TierComputedSpec(ComputeUnit): + ... 
+ gpu_model: Optional[str] + vram: Optional[int] + + +class Tier(BaseModel): + id: str + compute_units: int + vram: Optional[int] = None + model: Optional[str] = None + + def extract_tier_id(self) -> str: + return self.id.split("-", 1)[-1] + + +class PricingPerEntity(BaseModel): + price: Dict[str, Union[Price, Decimal]] + compute_unit: Optional[ComputeUnit] = None + tiers: Optional[List[Tier]] = None + + def _get_nb_compute_units( + self, + vcpus: int = 1, + memory_mib: int = 2048, + ) -> Optional[int]: + if self.compute_unit: + memory = math.ceil(memory_mib / self.compute_unit.memory_mib) + nb_compute = vcpus if vcpus >= memory else memory + return nb_compute + return None + + def get_closest_tier( + self, + vcpus: Optional[int] = None, + memory_mib: Optional[int] = None, + compute_unit: Optional[int] = None, + ): + """Get Closest tier for Program / Instance""" + + # We Calculate Compute Unit requested based on vcpus and memory + computed_cu = None + if vcpus is not None and memory_mib is not None: + computed_cu = self._get_nb_compute_units(vcpus=vcpus, memory_mib=memory_mib) + elif vcpus is not None and self.compute_unit is not None: + computed_cu = self._get_nb_compute_units( + vcpus=vcpus, memory_mib=self.compute_unit.memory_mib + ) + elif memory_mib is not None and self.compute_unit is not None: + computed_cu = self._get_nb_compute_units( + vcpus=self.compute_unit.vcpus, memory_mib=memory_mib + ) + + # Case where Vcpus or memory is given but also a number of CU (case on aleph-client) + cu: Optional[int] = None + if computed_cu is not None and compute_unit is not None: + if computed_cu != compute_unit: + logger.warning( + f"Mismatch in compute units: from CPU/RAM={computed_cu}, given={compute_unit}. " + f"Choosing {max(computed_cu, compute_unit)}." 
+ ) + cu = max(computed_cu, compute_unit) # We trust the bigger trier + else: + cu = compute_unit if compute_unit is not None else computed_cu + + # now tier found + if cu is None: + return None + + # With CU available, choose the closest one + candidates = self.tiers + if candidates is None: + return None + + best_tier = min( + candidates, + key=lambda t: (abs(t.compute_units - cu), -t.compute_units), + ) + return best_tier + + def get_services_specs( + self, + tier: Tier, + ) -> TierComputedSpec: + """ + Calculate ammount of vram / cpu / disk | + gpu model / vram if it GPU instance + """ + if self.compute_unit is None: + raise ValueError("ComputeUnit is required to get service specs") + + cpu = tier.compute_units * self.compute_unit.vcpus + memory_mib = tier.compute_units * self.compute_unit.memory_mib + disk = ( + tier.compute_units * self.compute_unit.disk_mib + ) # Min value disk can be increased + + # Gpu Specs + gpu = None + vram = None + if tier.model and tier.vram: + gpu = tier.model + vram = tier.vram + + return TierComputedSpec( + vcpus=cpu, + memory_mib=memory_mib, + disk_mib=disk, + gpu_model=gpu, + vram=vram, + ) + + +class PricingModel(RootModel[Dict[PricingEntity, PricingPerEntity]]): + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + +PRICING_GROUPS: dict[str, list[PricingEntity]] = { + GroupEntity.STORAGE: [PricingEntity.STORAGE], + GroupEntity.WEBSITE: [PricingEntity.WEB3_HOSTING], + GroupEntity.PROGRAM: [PricingEntity.PROGRAM, PricingEntity.PROGRAM_PERSISTENT], + GroupEntity.INSTANCE: [PricingEntity.INSTANCE], + GroupEntity.CONFIDENTIAL: [PricingEntity.INSTANCE_CONFIDENTIAL], + GroupEntity.GPU: [ + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ], + GroupEntity.ALL: list(PricingEntity), +} + +PAYG_GROUP: list[PricingEntity] = [ + PricingEntity.INSTANCE, + PricingEntity.INSTANCE_CONFIDENTIAL, + PricingEntity.INSTANCE_GPU_STANDARD, + 
PricingEntity.INSTANCE_GPU_PREMIUM, +] + + +class Pricing(BaseService[PricingModel]): + """ + This Service handle logic around Pricing + """ + + aggregate_key = "pricing" + model_cls = PricingModel + + def __init__(self, client): + super().__init__(client=client) + + # Config from aggregate + async def get_pricing_aggregate( + self, + ) -> PricingModel: + result = await self.get_config(address=settings.ALEPH_AGGREGATE_ADDRESS) + return result.data[0] + + async def get_pricing_for_services( + self, services: List[PricingEntity], pricing_info: Optional[PricingModel] = None + ) -> Dict[PricingEntity, PricingPerEntity]: + """ + Get pricing information for requested services + + Args: + services: List of pricing entities to get information for + pricing_info: Optional pre-fetched pricing aggregate + + Returns: + Dictionary with pricing information for requested services + """ + if ( + not pricing_info + ): # Avoid reloading aggregate info if there is already fetched + pricing_info = await self.get_pricing_aggregate() + + result = {} + for service in services: + if service in pricing_info: + result[service] = pricing_info[service] + + return result diff --git a/src/aleph/sdk/client/services/scheduler.py b/src/aleph/sdk/client/services/scheduler.py new file mode 100644 index 00000000..fdfaa5bc --- /dev/null +++ b/src/aleph/sdk/client/services/scheduler.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING, Optional + +import aiohttp +from aiohttp import ClientResponseError +from aleph_message.models import ItemHash + +from aleph.sdk.conf import settings +from aleph.sdk.types import AllocationItem, SchedulerNodes, SchedulerPlan +from aleph.sdk.utils import sanitize_url + +if TYPE_CHECKING: + from aleph.sdk.client.http import AlephHttpClient + + +class Scheduler: + """ + This Service is made to interact with scheduler API: + `https://scheduler.api.aleph.cloud/` + """ + + def __init__(self, client: "AlephHttpClient"): + self._client = client + + async def get_plan(self) -> 
SchedulerPlan: + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/plan" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + raw = await resp.json() + return SchedulerPlan.model_validate(raw) + + async def get_nodes(self) -> SchedulerNodes: + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/nodes" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + raw = await resp.json() + + return SchedulerNodes.model_validate(raw) + + async def get_allocation(self, vm_hash: ItemHash) -> Optional[AllocationItem]: + """ + Fetch allocation information for a given VM hash. + """ + url = f"{sanitize_url(settings.SCHEDULER_URL)}/api/v0/allocation/{vm_hash}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + payload = await resp.json() + return AllocationItem.model_validate(payload) + except ClientResponseError as e: + if e.status == 404: # Allocation can't be find on scheduler + return None + raise e diff --git a/src/aleph/sdk/client/services/settings.py b/src/aleph/sdk/client/services/settings.py new file mode 100644 index 00000000..9f4de76b --- /dev/null +++ b/src/aleph/sdk/client/services/settings.py @@ -0,0 +1,40 @@ +from typing import List + +from pydantic import BaseModel + +from aleph.sdk.conf import settings + +from .base import BaseService + + +class NetworkAvailableGpu(BaseModel): + name: str + model: str + vendor: str + device_id: str + + +class NetworkSettingsModel(BaseModel): + compatible_gpus: List[NetworkAvailableGpu] + last_crn_version: str + community_wallet_address: str + community_wallet_timestamp: int + + +class Settings(BaseService[NetworkSettingsModel]): + """ + This Service handle logic around Pricing + """ + + aggregate_key = "settings" + model_cls = NetworkSettingsModel + + def __init__(self, client): + super().__init__(client=client) + + # Config 
from aggregate
+    async def get_settings_aggregate(
+        self,
+    ) -> NetworkSettingsModel:
+        result = await self.get_config(address=settings.ALEPH_AGGREGATE_ADDRESS)
+        return result.data[0]
diff --git a/src/aleph/sdk/client/services/voucher.py b/src/aleph/sdk/client/services/voucher.py
new file mode 100644
index 00000000..eef351c4
--- /dev/null
+++ b/src/aleph/sdk/client/services/voucher.py
@@ -0,0 +1,164 @@
+from typing import Optional
+
+import aiohttp
+from aiohttp import ClientResponseError
+from aleph_message.models import Chain
+
+from aleph.sdk.conf import settings
+from aleph.sdk.query.filters import PostFilter
+from aleph.sdk.query.responses import Post, PostsResponse
+from aleph.sdk.types import Voucher, VoucherMetadata
+
+
+class Vouchers:
+    """
+    This service is made to fetch vouchers (SOL / EVM)
+    """
+
+    def __init__(self, client):
+        self._client = client
+
+    # Utils
+    def _resolve_address(self, address: str) -> str:
+        return address  # Unauthenticated client, so the address needs to be given
+
+    async def _fetch_voucher_update(self):
+        """
+        Fetch the latest EVM voucher update (unfiltered).
+        """
+
+        post_filter = PostFilter(
+            types=["vouchers-update"], addresses=[settings.VOUCHER_ORIGIN_ADDRESS]
+        )
+        vouchers_post: PostsResponse = await self._client.get_posts(
+            post_filter=post_filter, page_size=1
+        )
+
+        if not vouchers_post.posts:
+            return []
+
+        message_post: Post = vouchers_post.posts[0]
+
+        nft_vouchers = message_post.content.get("nft_vouchers", {})
+        return list(nft_vouchers.items())  # [(voucher_id, voucher_data)]
+
+    async def _fetch_solana_voucher_list(self):
+        """
+        Fetch full Solana voucher registry (unfiltered).
+ """ + try: + async with aiohttp.ClientSession() as session: + async with session.get(settings.VOUCHER_SOL_REGISTRY) as resp: + resp.raise_for_status() + return await resp.json() + except ClientResponseError: + return {} + + async def fetch_voucher_metadata( + self, metadata_id: str + ) -> Optional[VoucherMetadata]: + """ + Fetch metadata for a given voucher. + """ + url = f"https://claim.twentysix.cloud/sbt/metadata/{metadata_id}.json" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + resp.raise_for_status() + data = await resp.json() + return VoucherMetadata.model_validate(data) + except ClientResponseError: + return None + + async def get_solana_vouchers(self, address: str) -> list[Voucher]: + """ + Fetch Solana vouchers for a specific address. + """ + resolved_address = self._resolve_address(address=address) + vouchers: list[Voucher] = [] + + registry_data = await self._fetch_solana_voucher_list() + + claimed_tickets = registry_data.get("claimed_tickets", {}) + batches = registry_data.get("batches", {}) + + for ticket_hash, ticket_data in claimed_tickets.items(): + claimer = ticket_data.get("claimer") + if claimer != resolved_address: + continue + + batch_id = ticket_data.get("batch_id") + metadata_id = None + + if str(batch_id) in batches: + metadata_id = batches[str(batch_id)].get("metadata_id") + + if metadata_id: + metadata = await self.fetch_voucher_metadata(metadata_id) + if metadata: + voucher = Voucher( + id=ticket_hash, + metadata_id=metadata_id, + name=metadata.name, + description=metadata.description, + external_url=metadata.external_url, + image=metadata.image, + icon=metadata.icon, + attributes=metadata.attributes, + ) + vouchers.append(voucher) + + return vouchers + + async def get_evm_vouchers(self, address: str) -> list[Voucher]: + """ + Retrieve vouchers specific to EVM chains for a specific address. 
+ """ + resolved_address = self._resolve_address(address=address) + vouchers: list[Voucher] = [] + + nft_vouchers = await self._fetch_voucher_update() + for voucher_id, voucher_data in nft_vouchers: + if voucher_data.get("claimer") != resolved_address: + continue + + metadata_id = voucher_data.get("metadata_id") + metadata = await self.fetch_voucher_metadata(metadata_id) + if not metadata: + continue + + voucher = Voucher( + id=voucher_id, + metadata_id=metadata_id, + name=metadata.name, + description=metadata.description, + external_url=metadata.external_url, + image=metadata.image, + icon=metadata.icon, + attributes=metadata.attributes, + ) + vouchers.append(voucher) + return vouchers + + async def fetch_vouchers_by_chain(self, chain: Chain, address: str): + if chain == Chain.SOL: + return await self.get_solana_vouchers(address=address) + else: + return await self.get_evm_vouchers(address=address) + + async def get_vouchers(self, address: str) -> list[Voucher]: + """ + Retrieve all vouchers for the account / specific adress, across EVM and Solana chains. 
+ """ + vouchers = [] + + # Get EVM vouchers + if address.startswith("0x") and len(address) == 42: + evm_vouchers = await self.get_evm_vouchers(address=address) + vouchers.extend(evm_vouchers) + else: + # Get Solana vouchers + solana_vouchers = await self.get_solana_vouchers(address=address) + vouchers.extend(solana_vouchers) + + return vouchers diff --git a/src/aleph/sdk/client/vm_client.py b/src/aleph/sdk/client/vm_client.py new file mode 100644 index 00000000..f41c6ac1 --- /dev/null +++ b/src/aleph/sdk/client/vm_client.py @@ -0,0 +1,226 @@ +import datetime +import json +import logging +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +import aiohttp +from aiohttp.client import _RequestContextManager +from aleph_message.models import Chain, ItemHash +from eth_account.messages import encode_defunct +from jwcrypto import jwk + +from aleph.sdk.chains.solana import SOLAccount +from aleph.sdk.types import Account +from aleph.sdk.utils import ( + create_vm_control_payload, + sign_vm_control_payload, + to_0x_hex, +) + +logger = logging.getLogger(__name__) + + +class VmClient: + account: Account + ephemeral_key: jwk.JWK + node_url: str + pubkey_payload: Dict[str, Any] + pubkey_signature_header: str + session: aiohttp.ClientSession + + def __init__( + self, + account: Account, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + self.account = account + self.ephemeral_key = jwk.JWK.generate(kty="EC", crv="P-256") + self.node_url = node_url.rstrip("/") + self.pubkey_payload = self._generate_pubkey_payload( + Chain.SOL if isinstance(account, SOLAccount) else Chain.ETH + ) + self.pubkey_signature_header = "" + self.session = session or aiohttp.ClientSession() + + def _generate_pubkey_payload(self, chain: Chain = Chain.ETH) -> Dict[str, Any]: + return { + "pubkey": json.loads(self.ephemeral_key.export_public()), + "alg": "ECDSA", + "domain": self.node_domain, + "address": 
self.account.get_address(), + "expires": ( + datetime.datetime.utcnow() + datetime.timedelta(days=1) + ).isoformat() + + "Z", + "chain": chain.value, + } + + async def _generate_pubkey_signature_header(self) -> str: + pubkey_payload = json.dumps(self.pubkey_payload).encode("utf-8").hex() + if isinstance(self.account, SOLAccount): + buffer_to_sign = bytes(pubkey_payload, encoding="utf-8") + else: + signable_message = encode_defunct(hexstr=pubkey_payload) + buffer_to_sign = signable_message.body + + signed_message = await self.account.sign_raw(buffer_to_sign) + pubkey_signature = to_0x_hex(signed_message) + + return json.dumps( + { + "sender": self.account.get_address(), + "payload": pubkey_payload, + "signature": pubkey_signature, + "content": {"domain": self.node_domain}, + } + ) + + async def _generate_header( + self, vm_id: ItemHash, operation: str, method: str + ) -> Tuple[str, Dict[str, str]]: + payload = create_vm_control_payload( + vm_id, operation, domain=self.node_domain, method=method + ) + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + headers = { + "X-SignedPubKey": self.pubkey_signature_header, + "X-SignedOperation": signed_operation, + } + + path = payload["path"] + return f"{self.node_url}{path}", headers + + @property + def node_domain(self) -> str: + domain = urlparse(self.node_url).hostname + if not domain: + raise Exception("Could not parse node domain") + return domain + + async def perform_operation( + self, vm_id: ItemHash, operation: str, method: str = "POST" + ) -> Tuple[Optional[int], str]: + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) + + try: + async with self.session.request( + method=method, url=url, 
headers=header + ) as response: + response_text = await response.text() + return response.status, response_text + + except aiohttp.ClientError as e: + logger.error(f"HTTP error during operation {operation}: {str(e)}") + return None, str(e) + + def operate( + self, vm_id: ItemHash, operation: str, method: str = "POST" + ) -> _RequestContextManager: + """Request a CRN an operation for a VM (eg reboot, logs) + + This operation is authenticated via the user wallet. + Use as an async context manager. + e.g `async with client.operate(vm_id=item_hash, operation="logs", method="GET") as response:` + """ + + async def authenticated_request(): + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) + resp = await self.session._request( + method=method, str_or_url=url, headers=header + ) + return resp + + return _RequestContextManager(authenticated_request()) + + async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + payload = create_vm_control_payload( + vm_id, "stream_logs", method="get", domain=self.node_domain + ) + signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) + path = payload["path"] + ws_url = f"{self.node_url}{path}" + + async with self.session.ws_connect(ws_url) as ws: + auth_message = { + "auth": { + "X-SignedPubKey": json.loads(self.pubkey_signature_header), + "X-SignedOperation": json.loads(signed_operation), + } + } + await ws.send_json(auth_message) + + async for msg in ws: # msg is of type aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + yield msg.data + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + async def start_instance(self, vm_id: ItemHash) -> Tuple[int, str]: + return await 
self.notify_allocation(vm_id) + + async def stop_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "stop") + + async def reboot_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "reboot") + + async def erase_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "erase") + + async def expire_instance(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + return await self.perform_operation(vm_id, "expire") + + async def notify_allocation(self, vm_id: ItemHash) -> Tuple[int, str]: + json_data = {"instance": vm_id} + + async with self.session.post( + f"{self.node_url}/control/allocation/notify", json=json_data + ) as session: + form_response_text = await session.text() + + return session.status, form_response_text + + async def manage_instance( + self, vm_id: ItemHash, operations: List[str] + ) -> Tuple[int, str]: + for operation in operations: + status, response = await self.perform_operation(vm_id, operation) + if status != 200 and status: + return status, response + return 200, "All operations completed successfully" + + async def close(self): + await self.session.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py new file mode 100644 index 00000000..0d9d6e18 --- /dev/null +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -0,0 +1,216 @@ +import base64 +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.client.vm_client import VmClient +from aleph.sdk.types import Account, SEVMeasurement +from aleph.sdk.utils import ( + compute_confidential_measure, 
+ encrypt_secret_table, + get_vm_measure, + make_packet_header, + make_secret_table, + run_in_subprocess, +) + +logger = logging.getLogger(__name__) + + +class VmConfidentialClient(VmClient): + sevctl_path: Path + + def __init__( + self, + account: Account, + sevctl_path: Path, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + super().__init__(account, node_url, session) + self.sevctl_path = sevctl_path + + async def get_certificates(self) -> Tuple[Optional[int], str]: + """ + Get platform confidential certificate + """ + + url = f"{self.node_url}/about/certificates" + try: + async with self.session.get(url) as response: + data = await response.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(data) + return response.status, tmp_file.name + + except aiohttp.ClientError as e: + logger.error( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + return None, str(e) + + async def create_session( + self, certificate_prefix: str, platform_certificate_path: Path, policy: int + ) -> Path: + """ + Create new confidential session + """ + + current_path = Path().cwd() + args = [ + "session", + "--name", + certificate_prefix, + str(platform_certificate_path), + str(policy), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(*args) + return current_path + except Exception as e: + raise ValueError(f"Session creation have failed, reason: {str(e)}") + + async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str: + """ + Initialize Confidential VM negociation passing the needed session files + """ + + session_file = session.read_bytes() + godh_file = godh.read_bytes() + params = { + "session": session_file, + "godh": godh_file, + } + return await self.perform_confidential_operation( + vm_id, "confidential/initialize", params=params + ) + + async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: + """ + Fetch VM confidential measurement + """ + + if not 
self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + status, text = await self.perform_operation( + vm_id, "confidential/measurement", method="GET" + ) + sev_measurement = SEVMeasurement.model_validate_json(text) + return sev_measurement + + async def validate_measure( + self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str + ) -> bool: + """ + Validate VM confidential measurement + """ + + tik = tik_path.read_bytes() + vm_measure, nonce = get_vm_measure(sev_data) + + expected_measure = compute_confidential_measure( + sev_info=sev_data.sev_info, + tik=tik, + expected_hash=firmware_hash, + nonce=nonce, + ).digest() + return expected_measure == vm_measure + + async def build_secret( + self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str + ) -> Tuple[str, str]: + """ + Build disk secret to be injected on the confidential VM + """ + + tek = tek_path.read_bytes() + tik = tik_path.read_bytes() + + vm_measure, _ = get_vm_measure(sev_data) + + iv = os.urandom(16) + secret_table = make_secret_table(secret) + encrypted_secret_table = encrypt_secret_table( + secret_table=secret_table, tek=tek, iv=iv + ) + + packet_header = make_packet_header( + vm_measure=vm_measure, + encrypted_secret_table=encrypted_secret_table, + secret_table_size=len(secret_table), + tik=tik, + iv=iv, + ) + + encoded_packet_header = base64.b64encode(packet_header).decode() + encoded_secret = base64.b64encode(encrypted_secret_table).decode() + + return encoded_packet_header, encoded_secret + + async def inject_secret( + self, vm_id: ItemHash, packet_header: str, secret: str + ) -> Dict: + """ + Send the secret by the encrypted channel to boot up the VM + """ + + params = { + "packet_header": packet_header, + "secret": secret, + } + text = await self.perform_confidential_operation( + vm_id, "confidential/inject_secret", json=params + ) + + return json.loads(text) + + async def 
perform_confidential_operation( + self, + vm_id: ItemHash, + operation: str, + params: Optional[Dict[str, Any]] = None, + json=None, + ) -> str: + """ + Send confidential operations to the CRN passing the auth headers on each request + """ + + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method="post" + ) + + try: + async with self.session.post( + url, headers=header, data=params, json=json + ) as response: + response.raise_for_status() + response_text = await response.text() + return response_text + + except aiohttp.ClientError as e: + raise ValueError(f"HTTP error during operation {operation}: {str(e)}") + + async def sevctl_cmd(self, *args) -> bytes: + """ + Execute `sevctl` command with given arguments + """ + + return await run_in_subprocess( + [str(self.sevctl_path), *args], + check=True, + ) diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index 318536e4..ee91fc39 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -1,19 +1,34 @@ +import json +import logging import os +from enum import Enum from pathlib import Path from shutil import which -from typing import Optional +from typing import ClassVar, Dict, List, Optional, Union -from pydantic import BaseSettings, Field +from aleph_message.models import Chain +from aleph_message.models.execution.environment import HypervisorType +from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +from aleph.sdk.types import ChainInfo + +logger = logging.getLogger(__name__) class Settings(BaseSettings): CONFIG_HOME: Optional[str] = None + CONFIG_FILE: Path = Field( + default=Path("config.json"), + description="Path to the JSON file containing chain account configurations", + ) + # In case the user does not want to bother with handling private keys himself, # 
do an ugly and insecure write and read from disk to this file. PRIVATE_KEY_FILE: Path = Field( default=Path("ethereum.key"), - description="Path to the private key used to sign messages", + description="Path to the private key used to sign messages and transactions", ) PRIVATE_MNEMONIC_FILE: Path = Field( @@ -28,27 +43,322 @@ class Settings(BaseSettings): REMOTE_CRYPTO_HOST: Optional[str] = None REMOTE_CRYPTO_UNIX_SOCKET: Optional[str] = None ADDRESS_TO_USE: Optional[str] = None + HTTP_REQUEST_TIMEOUT: ClassVar[float] = 15.0 + + DEFAULT_CHANNEL: str = "ALEPH-CLOUDSOLUTIONS" + # Firecracker runtime for programs DEFAULT_RUNTIME_ID: str = ( - "f873715dc2feec3833074bd4b8745363a0e0093746b987b4c8191268883b2463" # Debian 12 official runtime + "63f07193e6ee9d207b7d1fcf8286f9aee34e6f12f101d2ec77c1229f92964696" + ) + + # Qemu rootfs for instances + DEBIAN_12_QEMU_ROOTFS_ID: str = ( + "b6ff5c3a8205d1ca4c7c3369300eeafff498b558f71b851aa2114afd0a532717" + ) + UBUNTU_22_QEMU_ROOTFS_ID: str = ( + "4a0f62da42f4478544616519e6f5d58adb1096e069b392b151d47c3609492d0c" + ) + UBUNTU_24_QEMU_ROOTFS_ID: str = ( + "5330dcefe1857bcd97b7b7f24d1420a7d46232d53f27be280c8a7071d88bd84e" + ) + + DEFAULT_CONFIDENTIAL_FIRMWARE: str = ( + "ba5bb13f3abca960b101a759be162b229e2b7e93ecad9d1307e54de887f177ff" + ) + DEFAULT_CONFIDENTIAL_FIRMWARE_HASH: str = ( + "89b76b0e64fe9015084fbffdf8ac98185bafc688bfe7a0b398585c392d03c7ee" ) + + DEFAULT_ROOTFS_SIZE: int = 20_480 + DEFAULT_INSTANCE_MEMORY: int = 2_048 + DEFAULT_HYPERVISOR: HypervisorType = HypervisorType.qemu + DEFAULT_VM_MEMORY: int = 256 DEFAULT_VM_VCPUS: int = 1 DEFAULT_VM_TIMEOUT: float = 30.0 CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None # True if command exists + VM_URL_PATH: ClassVar[str] = "https://aleph.sh/vm/{hash}" + VM_URL_HOST: ClassVar[str] = "https://{hash_base32}.aleph.sh" + IPFS_GATEWAY: ClassVar[str] = "https://ipfs.aleph.cloud/ipfs/" + CRN_URL_FOR_PROGRAMS: ClassVar[str] = "https://dchq.staging.aleph.sh/" + + DNS_API: 
ClassVar[str] = "https://api.dns.public.aleph.sh/instances/list" + CRN_URL_UPDATE: ClassVar[str] = "{crn_url}/control/machine/{vm_hash}/update" + CRN_LIST_URL: str = "https://crns-list.aleph.sh/crns.json" + CRN_VERSION_URL: ClassVar[str] = ( + "https://api.github.com/repos/aleph-im/aleph-vm/releases/latest" + ) + SCHEDULER_URL: ClassVar[str] = "https://scheduler.api.aleph.cloud/" + + VOUCHER_METDATA_TEMPLATE_URL: str = ( + "https://claim.twentysix.cloud/sbt/metadata/{}.json" + ) + VOUCHER_SOL_REGISTRY: str = "https://api.claim.twentysix.cloud/v1/registry/sol" + VOUCHER_ORIGIN_ADDRESS: str = "0xB34f25f2c935bCA437C061547eA12851d719dEFb" + + ALEPH_AGGREGATE_ADDRESS: str = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A" + + # Web3Provider settings + TOKEN_DECIMALS: ClassVar[int] = 18 + TX_TIMEOUT: ClassVar[int] = 60 * 3 + CHAINS: Dict[Union[Chain, str], ChainInfo] = { + # TESTNETS + "SEPOLIA": ChainInfo( + chain_id=11155111, + rpc="https://eth-sepolia.public.blastapi.io", + token="0xc4bf5cbdabe595361438f8c6a187bdc330539c60", + super_token="0x22064a21fee226d8ffb8818e7627d5ff6d0fc33a", + active=False, + ), + # MAINNETS + Chain.ARBITRUM: ChainInfo( + chain_id=42161, + rpc="https://arbitrum-one.publicnode.com", + ), + Chain.AURORA: ChainInfo( + chain_id=1313161554, + rpc="https://mainnet.aurora.dev", + ), + Chain.AVAX: ChainInfo( + chain_id=43114, + rpc="https://api.avax.network/ext/bc/C/rpc", + token="0xc0Fbc4967259786C743361a5885ef49380473dCF", + super_token="0xc0Fbc4967259786C743361a5885ef49380473dCF", + ), + Chain.BASE: ChainInfo( + chain_id=8453, + rpc="https://base-mainnet.public.blastapi.io", + token="0xc0Fbc4967259786C743361a5885ef49380473dCF", + super_token="0xc0Fbc4967259786C743361a5885ef49380473dCF", + ), + Chain.BLAST: ChainInfo( + chain_id=81457, + rpc="https://blastl2-mainnet.public.blastapi.io", + ), + Chain.BOB: ChainInfo( + chain_id=60808, + rpc="https://bob-mainnet.public.blastapi.io", + ), + Chain.BSC: ChainInfo( + chain_id=56, + 
rpc="https://binance.llamarpc.com", + token="0x82D2f8E02Afb160Dd5A480a617692e62de9038C4", + active=False, + ), + Chain.CYBER: ChainInfo( + chain_id=7560, + rpc="https://rpc.cyber.co", + ), + Chain.ETH: ChainInfo( + chain_id=1, + rpc="https://eth-mainnet.public.blastapi.io", + token="0x27702a26126e0B3702af63Ee09aC4d1A084EF628", + ), + Chain.ETHERLINK: ChainInfo( + chain_id=42793, + rpc="https://node.mainnet.etherlink.com", + ), + Chain.FRAXTAL: ChainInfo( + chain_id=252, + rpc="https://rpc.frax.com", + ), + Chain.HYPE: ChainInfo( + chain_id=999, + rpc="https://rpc.hyperliquid.xyz/evm", + ), + Chain.INK: ChainInfo( + chain_id=57073, + rpc="https://rpc-gel.inkonchain.com", + ), + Chain.LENS: ChainInfo( + chain_id=232, + rpc="https://rpc.lens.xyz", + ), + Chain.LINEA: ChainInfo( + chain_id=59144, + rpc="https://linea-rpc.publicnode.com", + ), + Chain.LISK: ChainInfo( + chain_id=1135, + rpc="https://rpc.api.lisk.com", + ), + Chain.METIS: ChainInfo( + chain_id=1088, + rpc="https://metis.drpc.org", + ), + Chain.MODE: ChainInfo( + chain_id=34443, + rpc="https://mode.drpc.org", + ), + Chain.OPTIMISM: ChainInfo( + chain_id=10, + rpc="https://optimism-rpc.publicnode.com", + ), + Chain.POL: ChainInfo( + chain_id=137, + rpc="https://polygon.gateway.tenderly.co", + ), + Chain.SOMNIA: ChainInfo( + chain_id=50312, + rpc="https://dream-rpc.somnia.network", + ), + Chain.SONIC: ChainInfo( + chain_id=146, + rpc="https://rpc.soniclabs.com", + ), + Chain.UNICHAIN: ChainInfo( + chain_id=130, + rpc="https://mainnet.unichain.org", + ), + Chain.WORLDCHAIN: ChainInfo( + chain_id=480, + rpc="https://worldchain-mainnet.gateway.tenderly.co", + ), + Chain.ZORA: ChainInfo( + chain_id=7777777, + rpc="https://rpc.zora.energy/", + ), + } + # Add all placeholders to allow easy dynamic setup of CHAINS + CHAINS_SEPOLIA_ACTIVE: Optional[bool] = None + CHAINS_ETH_ACTIVE: Optional[bool] = None + CHAINS_AVAX_ACTIVE: Optional[bool] = None + CHAINS_BASE_ACTIVE: Optional[bool] = None + CHAINS_BSC_ACTIVE: 
Optional[bool] = None + CHAINS_ARBITRUM_ACTIVE: Optional[bool] = None + CHAINS_AURORA_ACTIVE: Optional[bool] = None + CHAINS_BLAST_ACTIVE: Optional[bool] = None + CHAINS_BOB_ACTIVE: Optional[bool] = None + CHAINS_CYBER_ACTIVE: Optional[bool] = None + CHAINS_ETHERLINK_ACTIVE: Optional[bool] = None + CHAINS_FRAXTAL_ACTIVE: Optional[bool] = None + CHAINS_HYPE_ACTIVE: Optional[bool] = None + CHAINS_LENS_ACTIVE: Optional[bool] = None + CHAINS_LINEA_ACTIVE: Optional[bool] = None + CHAINS_LISK_ACTIVE: Optional[bool] = None + CHAINS_METIS_ACTIVE: Optional[bool] = None + CHAINS_MODE_ACTIVE: Optional[bool] = None + CHAINS_OPTIMISM_ACTIVE: Optional[bool] = None + CHAINS_POL_ACTIVE: Optional[bool] = None + CHAINS_SOMNIA_ACTIVE: Optional[bool] = None + CHAINS_SONIC_ACTIVE: Optional[bool] = None + CHAINS_UNICHAIN_ACTIVE: Optional[bool] = None + CHAINS_WORLDCHAIN_ACTIVE: Optional[bool] = None + CHAINS_ZORA_ACTIVE: Optional[bool] = None + + CHAINS_SEPOLIA_RPC: Optional[str] = None + CHAINS_ETH_RPC: Optional[str] = None + CHAINS_AVAX_RPC: Optional[str] = None + CHAINS_BASE_RPC: Optional[str] = None + CHAINS_BSC_RPC: Optional[str] = None + CHAINS_ARBITRUM_RPC: Optional[str] = None + CHAINS_AURORA_RPC: Optional[str] = None + CHAINS_BLAST_RPC: Optional[str] = None + CHAINS_BOB_RPC: Optional[str] = None + CHAINS_CYBER_RPC: Optional[str] = None + CHAINS_ETHERLINK_RPC: Optional[str] = None + CHAINS_FRAXTAL_RPC: Optional[str] = None + CHAINS_HYPE_RPC: Optional[str] = None + CHAINS_LENS_RPC: Optional[str] = None + CHAINS_LINEA_RPC: Optional[str] = None + CHAINS_LISK_RPC: Optional[str] = None + CHAINS_METIS_RPC: Optional[str] = None + CHAINS_MODE_RPC: Optional[str] = None + CHAINS_OPTIMISM_RPC: Optional[str] = None + CHAINS_POL_RPC: Optional[str] = None + CHAINS_SOMNIA_RPC: Optional[str] = None + CHAINS_SONIC_RPC: Optional[str] = None + CHAINS_UNICHAIN_RPC: Optional[str] = None + CHAINS_WORLDCHAIN_RPC: Optional[str] = None + CHAINS_ZORA_RPC: Optional[str] = None + + DEFAULT_CHAIN: Chain = 
Chain.ETH + # Dns resolver - DNS_IPFS_DOMAIN = "ipfs.public.aleph.sh" - DNS_PROGRAM_DOMAIN = "program.public.aleph.sh" - DNS_INSTANCE_DOMAIN = "instance.public.aleph.sh" - DNS_STATIC_DOMAIN = "static.public.aleph.sh" - DNS_RESOLVERS = ["9.9.9.9", "1.1.1.1"] + DNS_IPFS_DOMAIN: ClassVar[str] = "ipfs.public.aleph.sh" + DNS_PROGRAM_DOMAIN: ClassVar[str] = "program.public.aleph.sh" + DNS_INSTANCE_DOMAIN: ClassVar[str] = "instance.public.aleph.sh" + DNS_STATIC_DOMAIN: ClassVar[str] = "static.public.aleph.sh" + DNS_RESOLVERS: ClassVar[List[str]] = ["9.9.9.9", "1.1.1.1"] - class Config: - env_prefix = "ALEPH_" - case_sensitive = False - env_file = ".env" + model_config = SettingsConfigDict( + env_prefix="ALEPH_", case_sensitive=False, env_file=".env", extra="ignore" + ) + + +class AccountType(str, Enum): + IMPORTED: str = "imported" + HARDWARE: str = "hardware" + + +class MainConfiguration(BaseModel): + """ + Intern Chain Management with Account. + """ + + path: Optional[Path] = None + type: AccountType = AccountType.IMPORTED + chain: Chain + address: Optional[str] = None + derivation_path: Optional[str] = None + model_config = SettingsConfigDict(use_enum_values=True) + + @field_validator("type", mode="before") + def normalize_type(cls, v): + """Handle legacy 'internal'/'external' and accept both strings or enums.""" + if v is None: + return v + if isinstance(v, AccountType): + return v + v_str = str(v).lower().strip() + if v_str == "internal": + return AccountType.IMPORTED + elif v_str == "external": + return AccountType.HARDWARE + elif v_str in ("imported", "hardware"): + return AccountType(v_str) + raise ValueError(f"Unknown account type: {v}") + + @model_validator(mode="before") + def infer_type(cls, values: dict): + """ + Previously, the `type` field was optional to maintain backward compatibility + for users with older configurations (e.g., using a private key). 
+ + We now enforce `type` as required, but still handle legacy cases where it may + be missing by inferring its value automatically. + + Inference logic: + - If `type` is explicitly set, it is left unchanged. + - If `type` is missing: + - If `path` is provided → assume `imported` + - If only `address` is provided → assume `hardware` (Ledger) + (This scenario should not normally occur, but is handled for safety.) + - If both `path` and `address` are present → trust `path` (imported) + """ + + t = values.get("type") + path = values.get("path") + address = values.get("address") + + # If type already given , keep it + if t is not None: + return values + + # Infer if missing + if path: + values["type"] = AccountType.IMPORTED + elif address: + values["type"] = AccountType.HARDWARE + else: + raise ValueError( + "Cannot infer account type: please provide 'type', or 'path' (imported), or 'address' (hardware)." + ) + + return values # Settings singleton @@ -74,3 +384,62 @@ class Config: settings.PRIVATE_MNEMONIC_FILE = Path( settings.CONFIG_HOME, "private-keys", "substrate.mnemonic" ) +if str(settings.CONFIG_FILE) == "config.json": + settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "config.json") + # If Config file exist and well filled we update the PRIVATE_KEY_FILE default + if settings.CONFIG_FILE.exists(): + try: + with open(settings.CONFIG_FILE, "r", encoding="utf-8") as f: + config_data = json.load(f) + + if "path" in config_data and ( + "type" not in config_data or config_data["type"] == AccountType.IMPORTED + ): + settings.PRIVATE_KEY_FILE = Path(config_data["path"]) + except json.JSONDecodeError: + pass + + +# Update CHAINS settings and remove placeholders +CHAINS_ENV = [(key[7:], value) for key, value in settings if key.startswith("CHAINS_")] +for fields, value in CHAINS_ENV: + if value: + chain, field = fields.split("_", 1) + chain = chain if chain not in Chain.__members__ else Chain[chain] + field = field.lower() + settings.CHAINS[chain].__dict__[field] = value 
+ settings.__delattr__(f"CHAINS_{fields}") + + +def save_main_configuration(file_path: Path, data: MainConfiguration): + """ + Synchronously save a single ChainAccount object as JSON to a file. + """ + with file_path.open("w") as file: + data_serializable = data.model_dump() + if ( + data_serializable["path"] is not None + ): # Avoid having path : "None" in config file + data_serializable["path"] = str(data_serializable["path"]) + json.dump(data_serializable, file, indent=4) + + +def load_main_configuration(file_path: Path) -> Optional[MainConfiguration]: + """ + Synchronously load the private key and chain type from a file. + If the file does not exist or is empty, return None. + """ + if not file_path.exists() or file_path.stat().st_size == 0: + logger.debug(f"File {file_path} does not exist or is empty. Returning None.") + return None + + try: + with file_path.open("rb") as file: + content = file.read() + return MainConfiguration.model_validate_json(content.decode("utf-8")) + except UnicodeDecodeError as e: + logger.error(f"Unable to decode {file_path} as UTF-8: {e}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON format in {file_path}.") + + return None diff --git a/src/aleph/sdk/connectors/superfluid.py b/src/aleph/sdk/connectors/superfluid.py new file mode 100644 index 00000000..2d12080f --- /dev/null +++ b/src/aleph/sdk/connectors/superfluid.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from decimal import Decimal +from typing import TYPE_CHECKING, Optional + +from superfluid import CFA_V1, Operation, Web3FlowInfo +from web3 import Web3 +from web3.exceptions import ContractCustomError + +from aleph.sdk.evm_utils import ( + FlowUpdate, + from_wei_token, + get_super_token_address, + to_wei_token, +) +from aleph.sdk.exceptions import InsufficientFundsError +from aleph.sdk.types import TokenType + +if TYPE_CHECKING: + from aleph.sdk.chains.ethereum import BaseEthAccount + + +class Superfluid: + """ + Wrapper around the Superfluid APIs 
in order to CRUD Superfluid flows between two accounts.
+    """
+
+    account: BaseEthAccount
+    normalized_address: str
+    super_token: str
+    cfaV1Instance: CFA_V1
+    MIN_4_HOURS = 60 * 60 * 4
+
+    def __init__(self, account: BaseEthAccount):
+        self.account = account
+        self.normalized_address = Web3.to_checksum_address(account.get_address())
+        if account.chain:
+            self.super_token = str(get_super_token_address(account.chain))
+            self.cfaV1Instance = CFA_V1(account.rpc, account.chain_id)
+
+    # Helper functions
+    def _get_populated_transaction_request(self, operation, rpc: str):
+        """
+        Prepares the transaction to be signed by either imported / hardware wallets
+        @param operation - on chain operations
+        @param rpc - RPC URL
+        (the sender address is taken from self.normalized_address)
+        @returns - TxParams - The transaction object
+        """
+
+        call = (
+            operation.forwarder_call
+            if operation.forwarder_call is not None
+            else operation.agreement_call
+        )
+        populated_transaction = call.build_transaction(
+            {"from": self.normalized_address}
+        )
+
+        web3 = Web3(Web3.HTTPProvider(rpc))
+        nonce = web3.eth.get_transaction_count(self.normalized_address)
+
+        populated_transaction["nonce"] = nonce
+        return populated_transaction
+
+    def _simulate_create_tx_flow(self, flow: Decimal, block=True) -> bool:
+        try:
+            operation = self.cfaV1Instance.create_flow(
+                sender=self.normalized_address,
+                receiver=Web3.to_checksum_address(
+                    "0x0000000000000000000000000000000000000001"
+                ),  # Fake address; we do not sign/send this transaction
+                super_token=self.super_token,
+                flow_rate=int(to_wei_token(flow)),
+            )
+            if not self.account.rpc:
+                raise ValueError(
+                    f"RPC endpoint is required but not set for this chain {self.account.chain}."
+ ) + populated_transaction = self._get_populated_transaction_request( + operation=operation, + rpc=self.account.rpc, + ) + return self.account.can_transact(tx=populated_transaction, block=block) + except ContractCustomError as e: + if getattr(e, "data", None) == "0xea76c9b3": + balance = self.account.get_super_token_balance() + MIN_FLOW_4H = to_wei_token(flow) * Decimal(self.MIN_4_HOURS) + raise InsufficientFundsError( + token_type=TokenType.ALEPH, + required_funds=float(from_wei_token(MIN_FLOW_4H)), + available_funds=float(from_wei_token(balance)), + ) + return False + + async def _execute_operation_with_account(self, operation: Operation) -> str: + """ + Execute an operation using the provided account + @param operation - Operation instance from the library + @returns - str - Transaction hash + """ + if not self.account.rpc: + raise ValueError( + f"RPC endpoint is required but not set for this chain {self.account.chain}." + ) + + populated_transaction = self._get_populated_transaction_request( + operation=operation, rpc=self.account.rpc + ) + self.account.can_transact(tx=populated_transaction) + + return await self.account._sign_and_send_transaction(populated_transaction) + + def can_start_flow(self, flow: Decimal, block=True) -> bool: + """Check if the account has enough funds to start a Superfluid flow of the given size.""" + return self._simulate_create_tx_flow(flow=flow, block=block) + + async def create_flow(self, receiver: str, flow: Decimal) -> str: + """Create a Superfluid flow between two addresses.""" + return await self._execute_operation_with_account( + operation=self.cfaV1Instance.create_flow( + sender=self.normalized_address, + receiver=Web3.to_checksum_address(receiver), + super_token=self.super_token, + flow_rate=int(to_wei_token(flow)), + ), + ) + + async def get_flow(self, sender: str, receiver: str) -> Web3FlowInfo: + """Fetch information about the Superfluid flow between two addresses.""" + return self.cfaV1Instance.get_flow( + 
sender=Web3.to_checksum_address(sender),
+            receiver=Web3.to_checksum_address(receiver),
+            super_token=self.super_token,
+        )
+
+    async def delete_flow(self, receiver: str) -> str:
+        """Delete the Superfluid flow between two addresses."""
+        return await self._execute_operation_with_account(
+            operation=self.cfaV1Instance.delete_flow(
+                sender=self.normalized_address,
+                receiver=Web3.to_checksum_address(receiver),
+                super_token=self.super_token,
+            ),
+        )
+
+    async def update_flow(self, receiver: str, flow: Decimal) -> str:
+        """Update the flow of a Superfluid flow between two addresses."""
+        return await self._execute_operation_with_account(
+            operation=self.cfaV1Instance.update_flow(
+                sender=self.normalized_address,
+                receiver=Web3.to_checksum_address(receiver),
+                super_token=self.super_token,
+                flow_rate=int(to_wei_token(flow)),
+            ),
+        )
+
+    async def manage_flow(
+        self,
+        receiver: str,
+        flow: Decimal,
+        update_type: FlowUpdate,
+    ) -> Optional[str]:
+        """
+        Update the flow of a Superfluid stream between a sender and receiver.
+        This function either increases or decreases the flow rate between the sender and receiver,
+        based on the update_type. If no flow exists and the update type is augmentation, it creates a new flow
+        with the specified rate. If the update type is reduction and the reduction amount brings the flow to zero
+        or below, the flow is deleted.
+
+        :param receiver: Address of the receiver in hexadecimal format.
+        :param flow: The flow rate to be added or removed (in ether).
+        :param update_type: The type of update to perform (augmentation or reduction).
+        :return: The transaction hash of the executed operation (create, update, or delete flow).
+ """ + + # Retrieve current flow info + flow_info: Web3FlowInfo = await self.account.get_flow(receiver) + + current_flow_rate_wei: Decimal = Decimal(flow_info["flowRate"] or 0) + flow_rate_wei: int = int(to_wei_token(flow)) + + if update_type == FlowUpdate.INCREASE: + if current_flow_rate_wei > 0: + # Update existing flow by increasing the rate + new_flow_rate_wei = current_flow_rate_wei + flow_rate_wei + new_flow_rate_ether = from_wei_token(new_flow_rate_wei) + return await self.account.update_flow(receiver, new_flow_rate_ether) + else: + # Create a new flow if none exists + return await self.account.create_flow(receiver, flow) + else: + if current_flow_rate_wei > 0: + # Reduce the existing flow + new_flow_rate_wei = current_flow_rate_wei - flow_rate_wei + # Ensure to not leave infinitesimal flows + # Often, there were 1-10 wei remaining in the flow rate, which prevented the flow from being deleted + if new_flow_rate_wei > 99: + new_flow_rate_ether = from_wei_token(new_flow_rate_wei) + return await self.account.update_flow(receiver, new_flow_rate_ether) + else: + # Delete the flow if the new flow rate is zero or negative + return await self.account.delete_flow(receiver) + return None diff --git a/src/aleph/sdk/domain.py b/src/aleph/sdk/domain.py index a8f3fd82..70a53b08 100644 --- a/src/aleph/sdk/domain.py +++ b/src/aleph/sdk/domain.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse import aiodns +import tldextract from pydantic import BaseModel, HttpUrl from .conf import settings @@ -52,11 +53,11 @@ def raise_error(self, status: Dict[str, bool]): def hostname_from_url(url: Union[HttpUrl, str]) -> Hostname: """Extract FQDN from url""" - parsed = urlparse(url) + parsed = urlparse(str(url)) if all([parsed.scheme, parsed.netloc]) is True: url = parsed.netloc - return Hostname(url) + return Hostname(str(url)) async def get_target_type(fqdn: Hostname) -> Optional[TargetType]: @@ -198,6 +199,13 @@ async def check_domain( record_type = dns_rule.dns["type"] 
record_value = dns_rule.dns["value"] + if record_type == "alias": + # ALIAS records cannot be reliably validated via DNS since the + # provider resolves them to A records asynchronously. Consider + # the rule as valid and trust the user's configuration. + status[dns_rule.name] = True + continue + try: entries = await resolver.query(record_name, record_type.upper()) except aiodns.error.DNSError: @@ -207,7 +215,10 @@ async def check_domain( if entries: if record_type == "txt": for entry in entries: - if hasattr(entry, "text") and entry.text == record_value: + if ( + hasattr(entry, "text") + and str(entry.text).lower() == str(record_value).lower() + ): status[dns_rule.name] = True break elif ( @@ -246,19 +257,35 @@ def get_required_dns_rules( elif target == TargetType.INSTANCE: cname_value = f"{hostname}.{settings.DNS_INSTANCE_DOMAIN}" - # cname rule - dns_rules.append( - DNSRule( - name="cname", - dns={ - "type": "cname", - "name": hostname, - "value": cname_value, - }, - info=f"Create a CNAME record for {hostname} with value {cname_value}", - on_error=f"CNAME record not found: {hostname}", + # cname or alias rule + if self.is_root_domain(hostname): + record_type = "alias" + dns_rules.append( + DNSRule( + name=record_type, + dns={ + "type": record_type, + "name": hostname, + "value": cname_value, + }, + info=f"Create an ALIAS record for {hostname} with value {cname_value}", + on_error=f"ALIAS record not found: {hostname}", + ) + ) + else: + record_type = "cname" + dns_rules.append( + DNSRule( + name=record_type, + dns={ + "type": record_type, + "name": hostname, + "value": cname_value, + }, + info=f"Create a CNAME record for {hostname} with value {cname_value}", + on_error=f"CNAME record not found: {hostname}", + ) ) - ) if target == TargetType.IPFS: # ipfs rule @@ -291,3 +318,8 @@ def get_required_dns_rules( ) return dns_rules + + @staticmethod + def is_root_domain(hostname: Hostname) -> bool: + extracted = tldextract.extract(hostname) + return 
bool(extracted.domain) and not extracted.subdomain diff --git a/src/aleph/sdk/evm_utils.py b/src/aleph/sdk/evm_utils.py new file mode 100644 index 00000000..62cb902b --- /dev/null +++ b/src/aleph/sdk/evm_utils.py @@ -0,0 +1,112 @@ +from decimal import ROUND_CEILING, Context, Decimal +from enum import Enum +from typing import List, Optional, Union + +from aleph_message.models import Chain +from eth_utils import to_wei +from web3 import Web3 +from web3.types import ChecksumAddress + +from .conf import settings + +MIN_ETH_BALANCE: float = 0.001 +MIN_ETH_BALANCE_WEI = Decimal(to_wei(MIN_ETH_BALANCE, "ether")) +BALANCEOF_ABI = """[{ + "name": "balanceOf", + "inputs": [{"name": "account", "type": "address"}], + "outputs": [{"name": "balance", "type": "uint256"}], + "constant": true, + "payable": false, + "stateMutability": "view", + "type": "function" +}]""" + + +class FlowUpdate(str, Enum): + REDUCE = "reduce" + INCREASE = "increase" + + +def ether_rounding(amount: Decimal) -> Decimal: + """Rounds the given value to 18 decimals.""" + return amount.quantize( + Decimal(1) / Decimal(10**18), rounding=ROUND_CEILING, context=Context(prec=36) + ) + + +def from_wei_token(amount: Decimal) -> Decimal: + """Converts the given wei value to ether.""" + return ether_rounding(amount / Decimal(10) ** Decimal(settings.TOKEN_DECIMALS)) + + +def to_wei_token(amount: Decimal) -> Decimal: + """Converts the given ether value to wei.""" + return Decimal(int(amount * Decimal(10) ** Decimal(settings.TOKEN_DECIMALS))) + + +def get_chain_id(chain: Union[Chain, str, None]) -> Optional[int]: + """Returns the CHAIN_ID of a given EVM blockchain""" + if chain: + if chain in settings.CHAINS and settings.CHAINS[chain].chain_id: + return settings.CHAINS[chain].chain_id + else: + raise ValueError(f"Unknown RPC for chain {chain}") + return None + + +def get_rpc(chain: Union[Chain, str, None]) -> Optional[str]: + """Returns the RPC to use for a given EVM blockchain""" + if chain: + if chain in 
settings.CHAINS and settings.CHAINS[chain].rpc: + return settings.CHAINS[chain].rpc + else: + raise ValueError(f"Unknown RPC for chain {chain}") + return None + + +def get_token_address(chain: Union[Chain, str, None]) -> Optional[ChecksumAddress]: + if chain: + if chain in settings.CHAINS: + address = settings.CHAINS[chain].super_token + if address: + try: + return Web3.to_checksum_address(address) + except ValueError: + raise ValueError(f"Invalid token address {address}") + else: + raise ValueError(f"Unknown token for chain {chain}") + return None + + +def get_super_token_address( + chain: Union[Chain, str, None] +) -> Optional[ChecksumAddress]: + if chain: + if chain in settings.CHAINS: + address = settings.CHAINS[chain].super_token + if address: + try: + return Web3.to_checksum_address(address) + except ValueError: + raise ValueError(f"Invalid token address {address}") + else: + raise ValueError(f"Unknown super_token for chain {chain}") + return None + + +def get_compatible_chains() -> List[Union[Chain, str]]: + return [chain for chain, info in settings.CHAINS.items() if info.active] + + +def get_chains_with_holding() -> List[Union[Chain, str]]: + return [ + chain for chain, info in settings.CHAINS.items() if info.active and info.token + ] + + +def get_chains_with_super_token() -> List[Union[Chain, str]]: + return [ + chain + for chain, info in settings.CHAINS.items() + if info.active and info.super_token + ] diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index 39972f7f..c960f5a8 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -1,5 +1,10 @@ from abc import ABC +from aleph_message.status import MessageStatus + +from .types import TokenType +from .utils import displayable_amount + class QueryError(ABC, ValueError): """The result of an API query is inconsistent.""" @@ -19,6 +24,74 @@ class MultipleMessagesError(QueryError): pass +class MessageNotProcessed(Exception): + """ + The resources that you arte trying to 
interact with has not been processed
+    """
+
+    item_hash: str
+    status: MessageStatus
+
+    def __init__(
+        self,
+        item_hash: str,
+        status: MessageStatus,
+    ):
+        self.item_hash = item_hash
+        self.status = status
+        super().__init__(
+            f"Resources {item_hash} is not processed : {self.status.value}"
+        )
+
+
+class NotAuthorize(Exception):
+    """
+    Request not authorized. This can happen, for example, in port forwarding
+    if you try to set up ports for a VM that is not yours
+    """
+
+    item_hash: str
+    target_address: str
+    current_address: str
+
+    def __init__(self, item_hash: str, target_address, current_address):
+        self.item_hash = item_hash
+        self.target_address = target_address
+        self.current_address = current_address
+        super().__init__(
+            f"Operations not authorize on resources {self.item_hash} \nTarget address : {self.target_address} \nCurrent address : {self.current_address}"
+        )
+
+
+class VmNotFoundOnHost(Exception):
+    """
+    The VM was not found on the host.
+    The message might not be processed yet, or the CRN_URL is wrong
+    """
+
+    item_hash: str
+    crn_url: str
+
+    def __init__(
+        self,
+        item_hash: str,
+        crn_url,
+    ):
+        self.item_hash = item_hash
+        self.crn_url = crn_url
+
+        super().__init__(f"Vm : {self.item_hash} not found on crn : {self.crn_url}")
+
+
+class MethodNotAvailableOnCRN(Exception):
+    """
+    If this error appears, the CRN you are trying to interact with is outdated and does
+    not support this feature
+    """
+
+    pass
+
+
 class BroadcastError(Exception):
     """
     Data could not be broadcast to the aleph.im network.
@@ -66,15 +139,37 @@ class ForgottenMessageError(QueryError): pass +class RemovedMessageError(QueryError): + """The requested message was removed""" + + pass + + +class ResourceNotFoundError(QueryError): + """A message resource was expected but could not be found.""" + + pass + + class InsufficientFundsError(Exception): """Raised when the account does not have enough funds to perform an action""" + token_type: TokenType required_funds: float available_funds: float - def __init__(self, required_funds: float, available_funds: float): + def __init__( + self, token_type: TokenType, required_funds: float, available_funds: float + ): + self.token_type = token_type self.required_funds = required_funds self.available_funds = available_funds super().__init__( - f"Insufficient funds: required {required_funds}, available {available_funds}" + f"Insufficient funds ({self.token_type.value}): required {displayable_amount(self.required_funds, decimals=8)}, available {displayable_amount(self.available_funds, decimals=8)}" ) + + +class InvalidHashError(QueryError): + """The Hash is not valid""" + + pass diff --git a/src/aleph/sdk/query/filters.py b/src/aleph/sdk/query/filters.py index 4caee5f5..18f8b3f7 100644 --- a/src/aleph/sdk/query/filters.py +++ b/src/aleph/sdk/query/filters.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Dict, Iterable, Optional, Union -from aleph_message.models import MessageType +from aleph_message.models import Chain, MessageType from ..utils import _date_field_to_timestamp, enum_as_str, serialize_list @@ -56,6 +56,7 @@ class MessageFilter: def __init__( self, message_types: Optional[Iterable[MessageType]] = None, + message_statuses: Optional[Iterable[str]] = None, content_types: Optional[Iterable[str]] = None, content_keys: Optional[Iterable[str]] = None, refs: Optional[Iterable[str]] = None, @@ -82,6 +83,7 @@ def __init__( self.end_date = end_date self.sort_by = sort_by self.sort_order = sort_order + self.message_statuses = message_statuses 
def as_http_params(self) -> Dict[str, str]: """Convert the filters into a dict that can be used by an `aiohttp` client @@ -95,6 +97,7 @@ def as_http_params(self) -> Dict[str, str]: else None ), "contentTypes": serialize_list(self.content_types), + "message_statuses": serialize_list(self.message_statuses), "contentKeys": serialize_list(self.content_keys), "refs": serialize_list(self.refs), "addresses": serialize_list(self.addresses), @@ -193,3 +196,35 @@ def as_http_params(self) -> Dict[str, str]: result[key] = value return result + + +class BalanceFilter: + """ + A collection of filters that can be applied on Balance queries. + """ + + chain: Optional[Chain] + + def __init__( + self, + chain: Optional[Chain] = None, + ): + self.chain = chain + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = {"chain": enum_as_str(self.chain)} + + # Ensure all values are strings. 
+ result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/src/aleph/sdk/query/responses.py b/src/aleph/sdk/query/responses.py index 5fb91804..6efade14 100644 --- a/src/aleph/sdk/query/responses.py +++ b/src/aleph/sdk/query/responses.py @@ -1,5 +1,7 @@ from __future__ import annotations +import datetime as dt +from decimal import Decimal from typing import Any, Dict, List, Optional, Union from aleph_message.models import ( @@ -9,7 +11,7 @@ ItemType, MessageConfirmation, ) -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class Post(BaseModel): @@ -48,9 +50,9 @@ class Post(BaseModel): ref: Optional[Union[str, Any]] = Field( description="Other message referenced by this one" ) + address: Optional[str] = Field(description="Address of the sender") - class Config: - allow_extra = False + model_config = ConfigDict(extra="forbid") class PaginationResponse(BaseModel): @@ -64,11 +66,51 @@ class PostsResponse(PaginationResponse): """Response from an aleph.im node API on the path /api/v0/posts.json""" posts: List[Post] - pagination_item = "posts" + pagination_item: str = "posts" class MessagesResponse(PaginationResponse): """Response from an aleph.im node API on the path /api/v0/messages.json""" messages: List[AlephMessage] - pagination_item = "messages" + pagination_item: str = "messages" + + +class PriceResponse(BaseModel): + """Response from an aleph.im node API on the path /api/v0/price/{item_hash}""" + + required_tokens: Decimal + cost: Optional[str] = None + payment_type: str + + +class CreditsHistoryResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/credits""" + + address: str + credit_history: List[CreditHistoryResponseItem] + pagination_item: str = "credit_history" + + +class CreditHistoryResponseItem(BaseModel): + 
amount: int + ratio: Optional[Decimal] = None + tx_hash: Optional[str] = None + token: Optional[str] = None + chain: Optional[str] = None + provider: Optional[str] = None + origin: Optional[str] = None + origin_ref: Optional[str] = None + payment_method: Optional[str] = None + credit_ref: str + credit_index: int + expiration_date: Optional[dt.datetime] = None + message_timestamp: dt.datetime + + +class BalanceResponse(BaseModel): + address: str + balance: Decimal + details: Optional[Dict[str, Decimal]] = None + locked_amount: Decimal + credit_balance: int = 0 diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 2f57b280..1bbe66d1 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -1,10 +1,42 @@ from abc import abstractmethod +from datetime import datetime +from decimal import Decimal from enum import Enum -from typing import Dict, Protocol, TypeVar +from typing import ( + Any, + Dict, + Iterator, + List, + Literal, + Optional, + Protocol, + TypeVar, + Union, +) -__all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") +from aleph_message.models import ItemHash, MessageType +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PositiveInt, + RootModel, + TypeAdapter, + field_validator, +) +from typing_extensions import Self, runtime_checkable -from aleph_message.models import AlephMessage +__all__ = ( + "Authorization", + "AuthorizationBuilder", + "StorageEnum", + "Account", + "AccountFromPrivateKey", + "HardwareAccount", + "GenericMessage", +) + +from aleph_message.models import AlephMessage, Chain class StorageEnum(str, Enum): @@ -13,6 +45,7 @@ class StorageEnum(str, Enum): # Use a protocol to avoid importing crypto libraries +@runtime_checkable class Account(Protocol): CHAIN: str CURVE: str @@ -20,6 +53,9 @@ class Account(Protocol): @abstractmethod async def sign_message(self, message: Dict) -> Dict: ... + @abstractmethod + async def sign_raw(self, buffer: bytes) -> bytes: ... 
+ @abstractmethod def get_address(self) -> str: ... @@ -27,10 +63,413 @@ def get_address(self) -> str: ... def get_public_key(self) -> str: ... +@runtime_checkable class AccountFromPrivateKey(Account, Protocol): """Only accounts that are initialized from a private key string are supported.""" - def __init__(self, private_key: bytes): ... + def __init__(self, private_key: bytes, chain: Chain): ... + + async def sign_raw(self, buffer: bytes) -> bytes: ... + + def export_private_key(self) -> str: ... + + def switch_chain(self, chain: Optional[str] = None) -> None: ... + + +@runtime_checkable +class HardwareAccount(Account, Protocol): + """Account using hardware wallet.""" + + @staticmethod + def from_address( + address: str, device: Optional[Any] = None + ) -> Optional["HardwareAccount"]: ... + + @staticmethod + def from_path(path: str, device: Optional[Any] = None) -> "HardwareAccount": ... + + def get_address(self) -> str: ... + + def switch_chain(self, chain: Optional[str] = None) -> None: ... + + async def sign_message(self, message: Dict) -> Dict: ... + + async def sign_raw(self, buffer: bytes) -> bytes: ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) + + +class SEVInfo(BaseModel): + """ + An AMD SEV platform information. + """ + + enabled: bool + api_major: int + api_minor: int + build_id: int + policy: int + state: str + handle: int + + +class SEVMeasurement(BaseModel): + """ + A SEV measurement data get from Qemu measurement. + """ + + sev_info: SEVInfo + launch_measure: str + + +class ChainInfo(BaseModel): + """ + A chain information. + """ + + chain_id: int + rpc: str + token: Optional[str] = None + super_token: Optional[str] = None + active: bool = True + + +class StoredContent(BaseModel): + """ + A stored content. 
+ """ + + filename: Optional[str] = Field(default=None) + hash: Optional[str] = Field(default=None) + url: Optional[str] = Field(default=None) + error: Optional[str] = Field(default=None) + + +class TokenType(str, Enum): + """ + A token type. + """ + + GAS = "GAS" + ALEPH = "ALEPH" + CREDIT = "CREDIT" + + +# Scheduler +class Period(BaseModel): + start_timestamp: datetime + duration_seconds: float + + +class PlanItem(BaseModel): + persistent_vms: List[ItemHash] = Field(default_factory=list) + instances: List[ItemHash] = Field(default_factory=list) + on_demand_vms: List[ItemHash] = Field(default_factory=list) + jobs: List[str] = Field(default_factory=list) # adjust type if needed + + @field_validator( + "persistent_vms", "instances", "on_demand_vms", "jobs", mode="before" + ) + @classmethod + def coerce_to_list(cls, v: Any) -> List[Any]: + # Treat None or empty dict as empty list + if v is None or (isinstance(v, dict) and not v): + return [] + return v + + +class SchedulerPlan(BaseModel): + period: Period + plan: Dict[str, PlanItem] + + model_config = { + "populate_by_name": True, + } + + +class NodeItem(BaseModel): + node_id: str + url: str + ipv6: Optional[str] = None + supports_ipv6: bool + + +class SchedulerNodes(BaseModel): + nodes: List[NodeItem] + + model_config = { + "populate_by_name": True, + } + + def get_url(self, node_id: str) -> Optional[str]: + """ + Return the URL for the given node_id, or None if not found. 
+ """ + for node in self.nodes: + if node.node_id == node_id: + return node + return None + + +class AllocationItem(BaseModel): + vm_hash: ItemHash + vm_type: str + vm_ipv6: Optional[str] = None + period: Period + node: NodeItem + + model_config = { + "populate_by_name": True, + } + + +class InstanceWithScheduler(BaseModel): + source: Literal["scheduler"] + allocations: Optional[ + AllocationItem + ] # Case Scheduler (None == allocation can't be find on scheduler) + + +class InstanceManual(BaseModel): + source: Literal["manual"] + crn_url: str # Case + + +class InstanceAllocationsInfo( + RootModel[Dict[ItemHash, Union[InstanceManual, InstanceWithScheduler]]] +): + """ + RootModel holding mapping ItemHash to its Allocations. + Uses item_hash as the key instead of InstanceMessage objects to avoid hashability issues. + """ + + pass + + +# CRN Executions + + +class Networking(BaseModel): + ipv4: str + ipv6: str + + +class CrnExecutionV1(BaseModel): + networking: Networking + + +class PortMapping(BaseModel): + host: int + tcp: bool + udp: bool + + +class NetworkingV2(BaseModel): + ipv4_network: str + host_ipv4: str + ipv6_network: str + ipv6_ip: str + mapped_ports: Dict[str, PortMapping] + + +class VmStatus(BaseModel): + defined_at: Optional[datetime] + preparing_at: Optional[datetime] + prepared_at: Optional[datetime] + starting_at: Optional[datetime] + started_at: Optional[datetime] + stopping_at: Optional[datetime] + stopped_at: Optional[datetime] + + +class CrnExecutionV2(BaseModel): + networking: NetworkingV2 + status: VmStatus + running: bool + + +class CrnV1List(RootModel[Dict[ItemHash, CrnExecutionV1]]): + """ + V1: a dict whose keys are ItemHash (strings) + and whose values are VmItemV1 (just `networking`). + """ + + pass + + +class CrnV2List(RootModel[Dict[ItemHash, CrnExecutionV2]]): + """ + A RootModel whose root is a dict mapping each item‐hash (string) + to a CrnExecutionV2, exactly matching your JSON structure. 
+ """ + + pass + + +class InstancesExecutionList( + RootModel[Dict[ItemHash, Union[CrnExecutionV1, CrnExecutionV2]]] +): + """ + A Root Model representing Instances Message hashes and their Executions. + Uses ItemHash as keys to avoid hashability issues with InstanceMessage objects. + """ + + pass + + +class IPV4(BaseModel): + public: str + local: str + + +class Dns(BaseModel): + name: str + item_hash: ItemHash + ipv4: Optional[IPV4] + ipv6: str + + +DnsListAdapter = TypeAdapter(list[Dns]) + + +class PortFlags(BaseModel): + tcp: bool + udp: bool + + +class Ports(BaseModel): + ports: Dict[int, PortFlags] + + +AllForwarders = RootModel[Dict[ItemHash, Optional[Ports]]] + + +class DictLikeModel(BaseModel): + """ + Base class: behaves like a dict while still being a Pydantic model. + """ + + # allow extra fields + validate on assignment + model_config = ConfigDict(extra="allow", validate_assignment=True) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(self.model_dump().keys()) + + def __contains__(self, key: str) -> bool: + return hasattr(self, key) + + def keys(self): + return self.model_dump().keys() + + def values(self): + return self.model_dump().values() + + def items(self): + return self.model_dump().items() + + def get(self, key: str, default=None): + return getattr(self, key, default) + + +class VoucherAttribute(BaseModel): + value: Union[str, Decimal] + trait_type: str = Field(..., alias="trait_type") + display_type: Optional[str] = Field(None, alias="display_type") + + +class VoucherMetadata(BaseModel): + name: str + description: str + external_url: str + image: str + icon: str + attributes: list[VoucherAttribute] + + +class Voucher(BaseModel): + id: str + metadata_id: str + name: str + description: str + external_url: str + image: str + icon: str + attributes: list[VoucherAttribute] + + +class 
VmResources(BaseModel): + vcpus: PositiveInt + memory: PositiveInt + disk_mib: PositiveInt + + +class Authorization(BaseModel): + """A single authorization entry for delegated access.""" + + address: str + chain: Optional[Chain] = None + channels: list[str] = [] + types: list[MessageType] = [] + post_types: list[str] = [] + aggregate_keys: list[str] = [] + + +class AuthorizationBuilder: + def __init__(self, address: str): + self._address: str = address + self._chain: Optional[Chain] = None + self._channels: list[str] = [] + self._message_types: list[MessageType] = [] + self._post_types: list[str] = [] + self._aggregate_keys: list[str] = [] + + def chain(self, chain: Chain) -> Self: + self._chain = chain + return self + + def channel(self, channel: str) -> Self: + self._channels.append(channel) + return self + + def message_type(self, message_type: MessageType) -> Self: + self._message_types.append(message_type) + return self + + def post_type(self, post_type: str) -> Self: + if MessageType.post not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing POST message type first" + ) + self._post_types.append(post_type) + return self + + def aggregate_key(self, aggregate_key: str) -> Self: + if MessageType.aggregate not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing AGGREGATE message type first" + ) + self._aggregate_keys.append(aggregate_key) + return self + + def build(self) -> Authorization: + return Authorization( + address=self._address, + chain=self._chain, + channels=self._channels, + types=self._message_types, + post_types=self._post_types, + aggregate_keys=self._aggregate_keys, + ) + + +class SecurityAggregateContent(BaseModel): + """Content schema for the 'security' aggregate.""" + + authorizations: list[Authorization] = [] diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index ab17f44a..94bc3bb9 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,12 +1,21 @@ 
+import asyncio +import base64 import errno +import hashlib +import hmac +import json import logging import os +import re +import subprocess from datetime import date, datetime, time +from decimal import Context, Decimal, InvalidOperation from enum import Enum from pathlib import Path from shutil import make_archive from typing import ( Any, + Dict, Iterable, Mapping, Optional, @@ -17,15 +26,49 @@ Union, get_args, ) +from urllib.parse import urlparse +from uuid import UUID from zipfile import BadZipFile, ZipFile -from aleph_message.models import MessageType -from aleph_message.models.execution.program import Encoding -from aleph_message.models.execution.volume import MachineVolume -from pydantic.json import pydantic_encoder +import pydantic_core +from aleph_message.models import ( + Chain, + InstanceContent, + ItemHash, + MachineType, + MessageType, + ProgramContent, +) +from aleph_message.models.execution.base import Payment, PaymentType +from aleph_message.models.execution.environment import ( + FunctionEnvironment, + FunctionTriggers, + HostRequirements, + HypervisorType, + InstanceEnvironment, + MachineResources, + Subscription, + TrustedExecutionEnvironment, +) +from aleph_message.models.execution.instance import RootfsVolume +from aleph_message.models.execution.program import ( + CodeContent, + Encoding, + FunctionRuntime, +) +from aleph_message.models.execution.volume import ( + MachineVolume, + ParentVolume, + PersistentVolumeSizeMib, + VolumePersistence, +) +from aleph_message.utils import Mebibytes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from jwcrypto.jwa import JWA from aleph.sdk.conf import settings -from aleph.sdk.types import GenericMessage +from aleph.sdk.types import GenericMessage, SEVInfo, SEVMeasurement logger = logging.getLogger(__name__) @@ -161,20 +204,422 @@ def extended_json_encoder(obj: Any) -> Any: elif isinstance(obj, time): return obj.hour * 
3600 + obj.minute * 60 + obj.second + obj.microsecond / 1e6 else: - return pydantic_encoder(obj) + return pydantic_core.to_jsonable_python(obj) def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume: - # Python 3.9 does not support `isinstance(volume_dict, MachineVolume)`, - # so we need to iterate over all types. if any( isinstance(volume_dict, volume_type) for volume_type in get_args(MachineVolume) ): - return volume_dict + return volume_dict # type: ignore + for volume_type in get_args(MachineVolume): try: - return volume_type.parse_obj(volume_dict) + return volume_type.model_validate(volume_dict) except ValueError: - continue - else: - raise ValueError(f"Could not parse volume: {volume_dict}") + pass + raise ValueError(f"Could not parse volume: {volume_dict}") + + +def compute_sha256(s: str) -> str: + """Compute the SHA256 hash of a string.""" + return hashlib.sha256(s.encode()).hexdigest() + + +def to_0x_hex(b: bytes) -> str: + return "0x" + bytes.hex(b) + + +def bytes_from_hex(hex_string: str) -> bytes: + if hex_string.startswith("0x"): + hex_string = hex_string[2:] + hex_string = bytes.fromhex(hex_string) + return hex_string + + +def create_vm_control_payload( + vm_id: ItemHash, operation: str, domain: str, method: str +) -> Dict[str, str]: + path = f"/control/machine/{vm_id}/{operation}" + payload = { + "time": datetime.utcnow().isoformat() + "Z", + "method": method.upper(), + "path": path, + "domain": domain, + } + return payload + + +def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: + payload_as_bytes = json.dumps(payload).encode("utf-8") + payload_signature = JWA.signing_alg("ES256").sign(ephemeral_key, payload_as_bytes) + signed_operation = json.dumps( + { + "payload": payload_as_bytes.hex(), + "signature": payload_signature.hex(), + } + ) + return signed_operation + + +async def run_in_subprocess( + command: list[str], check: bool = True, stdin_input: Optional[bytes] = None +) -> bytes: + """Run the 
 specified command in a subprocess, returns the stdout of the process.""" + logger.debug(f"command: {' '.join(command)}") + + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(input=stdin_input) + + if check and process.returncode: + logger.error( + f"Command failed with error code {process.returncode}:\n" + f" stdin = {stdin_input!r}\n" + f" command = {command}\n" + f" stderr = {stderr!r}" + ) + raise subprocess.CalledProcessError( + process.returncode, str(command), stderr.decode() + ) + + return stdout + + +def get_vm_measure(sev_data: SEVMeasurement) -> Tuple[bytes, bytes]: + launch_measure = base64.b64decode(sev_data.launch_measure) + vm_measure = launch_measure[0:32] + nonce = launch_measure[32:48] + return vm_measure, nonce + + +def calculate_firmware_hash(firmware_path: Path) -> str: + """Calculate the hash of the firmware (OVMF) file to be used in validating the measurements + + Returned as hex encoded string""" + + # https://www.qemu.org/docs/master/system/i386/amd-memory-encryption.html + # The value of GCTX.LD is SHA256(firmware_blob || kernel_hashes_blob || vmsas_blob), where: + # firmware_blob is the content of the entire firmware flash file (for example, OVMF.fd). [...] + # and verified against sevctl, see tests + firmware_content = firmware_path.read_bytes() + hash_calculator = hashlib.sha256(firmware_content) + + return hash_calculator.hexdigest() + + +def compute_confidential_measure( + sev_info: SEVInfo, tik: bytes, expected_hash: str, nonce: bytes +) -> hmac.HMAC: + """ + Computes the SEV measurement using the CRN SEV data and local variables like the OVMF firmware hash, + and the session key generated. 
+ """ + + h = hmac.new(tik, digestmod="sha256") + + ## + # calculated per section 6.5.2 + ## + h.update(bytes([0x04])) + h.update(sev_info.api_major.to_bytes(1, byteorder="little")) + h.update(sev_info.api_minor.to_bytes(1, byteorder="little")) + h.update(sev_info.build_id.to_bytes(1, byteorder="little")) + h.update(sev_info.policy.to_bytes(4, byteorder="little")) + + expected_hash_bytes = bytearray.fromhex(expected_hash) + h.update(expected_hash_bytes) + + h.update(nonce) + + return h + + +def make_secret_table(secret: str) -> bytearray: + """ + Makes the disk secret table to be sent to the Confidential CRN + """ + + ## + # Construct the secret table: two guids + 4 byte lengths plus string + # and zero terminator + # + # Secret layout is guid, len (4 bytes), data + # with len being the length from start of guid to end of data + # + # The table header covers the entire table then each entry covers + # only its local data + # + # our current table has the header guid with total table length + # followed by the secret guid with the zero terminated secret + ## + + # total length of table: header plus one entry with trailing \0 + length = 16 + 4 + 16 + 4 + len(secret) + 1 + # SEV-ES requires rounding to 16 + length = (length + 15) & ~15 + secret_table = bytearray(length) + + secret_table[0:16] = UUID("{1e74f542-71dd-4d66-963e-ef4287ff173b}").bytes_le + secret_table[16:20] = len(secret_table).to_bytes(4, byteorder="little") + secret_table[20:36] = UUID("{736869e5-84f0-4973-92ec-06879ce3da0b}").bytes_le + secret_table[36:40] = (16 + 4 + len(secret) + 1).to_bytes(4, byteorder="little") + secret_table[40 : 40 + len(secret)] = secret.encode() + + return secret_table + + +def encrypt_secret_table(secret_table: bytes, tek: bytes, iv: bytes) -> bytes: + """Encrypt the secret table with the TEK in CTR mode using a random IV""" + + # Initialize the cipher with AES algorithm and CTR mode + cipher = Cipher(algorithms.AES(tek), modes.CTR(iv), backend=default_backend()) + encryptor 
= cipher.encryptor() + + # Encrypt the secret table + encrypted_secret = encryptor.update(secret_table) + encryptor.finalize() + + return encrypted_secret + + +def make_packet_header( + vm_measure: bytes, + encrypted_secret_table: bytes, + secret_table_size: int, + tik: bytes, + iv: bytes, +) -> bytearray: + """ + Creates a packet header using the encrypted disk secret table to be sent to the Confidential CRN + """ + + ## + # ultimately needs to be an argument, but there's only + # compressed and no real use case + ## + flags = 0 + + ## + # Table 55. LAUNCH_SECRET Packet Header Buffer + ## + header = bytearray(52) + header[0:4] = flags.to_bytes(4, byteorder="little") + header[4:20] = iv + + h = hmac.new(tik, digestmod="sha256") + h.update(bytes([0x01])) + # FLAGS || IV + h.update(header[0:20]) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(encrypted_secret_table) + h.update(vm_measure) + + header[20:52] = h.digest() + + return header + + +def safe_getattr(obj, attr, default=None): + for part in attr.split("."): + obj = getattr(obj, part, default) + if obj is default: + break + return obj + + +def displayable_amount( + amount: Union[str, int, float, Decimal], decimals: int = 18 +) -> str: + """Returns the amount as a string without unnecessary decimals.""" + + str_amount = "" + try: + dec_amount = Decimal(amount) + if decimals: + dec_amount = dec_amount.quantize( + Decimal(1) / Decimal(10**decimals), context=Context(prec=36) + ) + str_amount = str(format(dec_amount.normalize(), "f")) + except ValueError: + logger.error(f"Invalid amount to display: {amount}") + exit(1) + except InvalidOperation: + logger.error(f"Invalid operation on amount to display: {amount}") + exit(1) + return str_amount + + +def make_instance_content( + rootfs: str, + rootfs_size: int, + payment: Optional[Payment] = None, + environment_variables: Optional[dict[str, str]] = None, + address: Optional[str] = 
None, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + hypervisor: Optional[HypervisorType] = None, + trusted_execution: Optional[TrustedExecutionEnvironment] = None, + volumes: Optional[list[Mapping]] = None, + ssh_keys: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + requirements: Optional[HostRequirements] = None, +) -> InstanceContent: + """ + Create InstanceContent object given the provided fields. + """ + + address = address or "0x0000000000000000000000000000000000000000" + payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold, receiver=None) + selected_hypervisor: HypervisorType = hypervisor or HypervisorType.qemu + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + memory = memory or settings.DEFAULT_VM_MEMORY + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + volumes = volumes if volumes is not None else [] + + return InstanceContent( + address=address, + allow_amend=allow_amend, + environment=InstanceEnvironment( + internet=internet, + aleph_api=aleph_api, + hypervisor=selected_hypervisor, + trusted_execution=trusted_execution, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=Mebibytes(memory), + seconds=int(timeout_seconds), + ), + rootfs=RootfsVolume( + parent=ParentVolume( + ref=ItemHash(rootfs), + use_latest=True, + ), + size_mib=PersistentVolumeSizeMib(rootfs_size), + persistence=VolumePersistence.host, + ), + volumes=[parse_volume(volume) for volume in volumes], + requirements=requirements, + time=datetime.now().timestamp(), + authorized_keys=ssh_keys, + metadata=metadata, + payment=payment, + ) + + +def make_program_content( + program_ref: str, + entrypoint: str, + runtime: str, + metadata: Optional[dict[str, Any]] = None, + address: Optional[str] = None, + vcpus: Optional[int] = None, + memory: Optional[int] = 
None, + timeout_seconds: Optional[float] = None, + internet: bool = False, + aleph_api: bool = True, + allow_amend: bool = False, + encoding: Encoding = Encoding.zip, + persistent: bool = False, + volumes: Optional[list[Mapping]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[list[dict]] = None, + payment: Optional[Payment] = None, +) -> ProgramContent: + """ + Create ProgramContent object given the provided fields. + """ + + address = address or "0x0000000000000000000000000000000000000000" + payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold, receiver=None) + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + memory = memory or settings.DEFAULT_VM_MEMORY + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + volumes = volumes if volumes is not None else [] + subscriptions = ( + [Subscription(**sub) for sub in subscriptions] + if subscriptions is not None + else None + ) + + return ProgramContent( + type=MachineType.vm_function, + address=address, + allow_amend=allow_amend, + code=CodeContent( + encoding=encoding, + entrypoint=entrypoint, + ref=ItemHash(program_ref), + use_latest=True, + ), + on=FunctionTriggers( + http=True, + persistent=persistent, + message=subscriptions, + ), + environment=FunctionEnvironment( + reproducible=False, + internet=internet, + aleph_api=aleph_api, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=Mebibytes(memory), + seconds=int(timeout_seconds), + ), + runtime=FunctionRuntime( + ref=ItemHash(runtime), + use_latest=True, + comment=( + "Official aleph.im runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "" + ), + ), + volumes=[parse_volume(volume) for volume in volumes], + time=datetime.now().timestamp(), + metadata=metadata, + authorized_keys=[], + payment=payment, + ) + + +def sanitize_url(url: str) -> str: + """ + Sanitize a URL by removing the trailing slash and ensuring it's properly formatted. 
+ + Args: + url: The URL to sanitize + + Returns: + The sanitized URL + """ + # Remove trailing slash if present + url = url.rstrip("/") + + # Ensure URL has a proper scheme + parsed = urlparse(url) + if not parsed.scheme: + url = f"https://{url}" + + return url + + +def extract_valid_eth_address(address: str) -> str: + if address: + pattern = r"0x[a-fA-F0-9]{40}" + match = re.search(pattern, address) + if match: + return match.group(0) + return "" diff --git a/src/aleph/sdk/vm/cache.py b/src/aleph/sdk/vm/cache.py index ff5ca7c8..a7ac6acc 100644 --- a/src/aleph/sdk/vm/cache.py +++ b/src/aleph/sdk/vm/cache.py @@ -70,7 +70,7 @@ def __init__( ) self.cache = {} - self.api_host = connector_url if connector_url else settings.API_HOST + self.api_host = str(connector_url) if connector_url else settings.API_HOST async def get(self, key: str) -> Optional[bytes]: sanitized_key = sanitize_cache_key(key) diff --git a/src/aleph/sdk/wallets/ledger/ethereum.py b/src/aleph/sdk/wallets/ledger/ethereum.py index 2ecdc5d3..a09958d6 100644 --- a/src/aleph/sdk/wallets/ledger/ethereum.py +++ b/src/aleph/sdk/wallets/ledger/ethereum.py @@ -1,34 +1,58 @@ from __future__ import annotations +import asyncio +import logging from typing import Dict, List, Optional +from aleph_message.models import Chain from eth_typing import HexStr from ledgerblue.Dongle import Dongle from ledgereth import find_account, get_account_by_path, get_accounts from ledgereth.comms import init_dongle from ledgereth.messages import sign_message from ledgereth.objects import LedgerAccount, SignedMessage +from ledgereth.transactions import Type2Transaction, sign_transaction +from web3.types import TxReceipt -from ...chains.common import BaseAccount, bytes_from_hex, get_verification_buffer +from ...chains.common import get_verification_buffer +from ...chains.ethereum import BaseEthAccount +from ...utils import bytes_from_hex +logger = logging.getLogger(__name__) -class LedgerETHAccount(BaseAccount): + +class 
LedgerETHAccount(BaseEthAccount): """Account using the Ethereum app on Ledger hardware wallets.""" - CHAIN = "ETH" - CURVE = "secp256k1" _account: LedgerAccount _device: Dongle - def __init__(self, account: LedgerAccount, device: Dongle): + def __init__( + self, account: LedgerAccount, device: Dongle, chain: Optional[Chain] = None + ): """Initialize an aleph.im account instance that relies on a LedgerHQ device and the Ethereum Ledger application for signatures. See the static methods `self.from_address(...)` and `self.from_path(...)` for an easier method of instantiation. """ + super().__init__(chain=None) + self._account = account self._device = device + if chain: + self.connect_chain(chain=chain) + + @staticmethod + def get_accounts( + device: Optional[Dongle] = None, count: int = 5 + ) -> List[LedgerAccount]: + """Initialize an aleph.im account from a LedgerHQ device from + a known wallet address. + """ + device = device or init_dongle() + accounts: List[LedgerAccount] = get_accounts(dongle=device, count=count) + return accounts @staticmethod def from_address( @@ -67,7 +91,12 @@ async def sign_message(self, message: Dict) -> Dict: # TODO: Check why the code without a wallet uses `encode_defunct`. 
msghash: bytes = get_verification_buffer(message) - sig: SignedMessage = sign_message(msghash, dongle=self._device) + logger.warning( + "Please Sign messages using ledger" + ) # allow to propagate it to cli + sig: SignedMessage = sign_message( + msghash, dongle=self._device, sender_path=self._account.path + ) signature: HexStr = sig.signature @@ -76,10 +105,66 @@ async def sign_message(self, message: Dict) -> Dict: async def sign_raw(self, buffer: bytes) -> bytes: """Sign a raw buffer.""" - sig: SignedMessage = sign_message(buffer, dongle=self._device) + logger.warning( + "Please sign the message on your Ledger device" + ) # allow to propagate it to cli + sig: SignedMessage = sign_message( + buffer, dongle=self._device, sender_path=self._account.path + ) signature: HexStr = sig.signature return bytes_from_hex(signature) + async def _sign_and_send_transaction(self, tx_params: dict) -> str: + """ + Sign and broadcast a transaction using the Ledger hardware wallet. + Equivalent of the software _sign_and_send_transaction(). 
+ + @param tx_params: dict - Transaction parameters + @returns: str - Transaction hash + """ + if self._provider is None: + raise ValueError("Provider not connected") + + def sign_and_send() -> TxReceipt: + logger.warning( + "Please Sign messages using ledger" + ) # allow to propagate it to cli + + # Type2Transaction + tx = Type2Transaction( + chain_id=tx_params["chainId"], + nonce=tx_params["nonce"], + max_priority_fee_per_gas=tx_params["maxPriorityFeePerGas"], + max_fee_per_gas=tx_params["maxFeePerGas"], + gas_limit=tx_params["gas"], + destination=bytes.fromhex(tx_params["to"][2:]), + amount=tx_params["value"], + data=bytes.fromhex(tx_params["data"][2:]), + ) + signed_tx = sign_transaction( + tx=tx, + sender_path=self._account.path, + dongle=self._device, + ) + + provider = self._provider + if provider is None: + raise ValueError("Provider not connected") + + tx_hash = provider.eth.send_raw_transaction(signed_tx.rawTransaction) + + tx_receipt = provider.eth.wait_for_transaction_receipt( + tx_hash, + timeout=getattr(self, "TX_TIMEOUT", 120), # optional custom timeout + ) + + return tx_receipt + + loop = asyncio.get_running_loop() + tx_receipt = await loop.run_in_executor(None, sign_and_send) + + return tx_receipt["transactionHash"].hex() + def get_address(self) -> str: return self._account.address diff --git a/tests/integration/config.py b/tests/integration/config.py index 3e613c18..bd78ef1a 100644 --- a/tests/integration/config.py +++ b/tests/integration/config.py @@ -1,3 +1,3 @@ -TARGET_NODE = "https://api1.aleph.im" +TARGET_NODE = "https://api3.aleph.im" REFERENCE_NODE = "https://api2.aleph.im" TEST_CHANNEL = "INTEGRATION_TESTS" diff --git a/tests/integration/fixtures/testStore.txt b/tests/integration/fixtures/testStore.txt new file mode 100644 index 00000000..865be812 --- /dev/null +++ b/tests/integration/fixtures/testStore.txt @@ -0,0 +1 @@ +Never gonna give you up. 
\ No newline at end of file diff --git a/tests/integration/itest_store.py b/tests/integration/itest_store.py new file mode 100644 index 00000000..0831bee3 --- /dev/null +++ b/tests/integration/itest_store.py @@ -0,0 +1,65 @@ +import pytest + +from aleph.sdk.client import AuthenticatedAlephHttpClient +from aleph.sdk.query.filters import MessageFilter +from tests.integration.toolkit import has_messages, try_until + +from .config import REFERENCE_NODE, TARGET_NODE + + +async def create_store_on_target(account, emitter_node: str, receiver_node: str): + """ + Create a POST message on the target node, then fetch it from the reference node and download the file. + """ + with open("tests/integration/fixtures/testStore.txt", "rb") as f: + file_content = f.read() + async with AuthenticatedAlephHttpClient( + account=account, api_server=emitter_node + ) as tx_session: + store_message, message_status = await tx_session.create_store( + file_content=file_content, + extra_fields={"test": "test"}, + ) + + async with AuthenticatedAlephHttpClient( + account=account, api_server=receiver_node + ) as rx_session: + responses = await try_until( + rx_session.get_messages, + has_messages, + timeout=5, + message_filter=MessageFilter( + hashes=[store_message.item_hash], + ), + ) + + message_from_target = responses.messages[0] + assert store_message.item_hash == message_from_target.item_hash + + async with AuthenticatedAlephHttpClient( + account=account, api_server=receiver_node + ) as rx_session: + store_content = await rx_session.download_file(store_message.content.item_hash) + assert store_content == file_content + + +@pytest.mark.asyncio +async def test_create_message_on_target(fixture_account): + """ + Attempts to create a new message on the target node and verifies if the message can be fetched from + the reference node. 
+ """ + await create_store_on_target( + fixture_account, emitter_node=TARGET_NODE, receiver_node=REFERENCE_NODE + ) + + +@pytest.mark.asyncio +async def test_create_message_on_reference(fixture_account): + """ + Attempts to create a new message on the reference node and verifies if the message can be fetched from + the target node. + """ + await create_store_on_target( + fixture_account, emitter_node=REFERENCE_NODE, receiver_node=TARGET_NODE + ) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py new file mode 100644 index 00000000..d25b5177 --- /dev/null +++ b/tests/unit/aleph_vm_authentication.py @@ -0,0 +1,290 @@ +# Keep datetime import as is as it allow patching in test +from __future__ import annotations + +import datetime +import functools +import json +import logging +from collections.abc import Awaitable, Coroutine +from typing import Any, Callable, Dict, Literal, Optional, Union + +import cryptography.exceptions +import pydantic +from aiohttp import web +from eth_account import Account +from eth_account.messages import encode_defunct +from jwcrypto import jwk +from jwcrypto.jwa import JWA +from pydantic import BaseModel, ValidationError, field_validator, model_validator +from typing_extensions import Self + +from aleph.sdk.utils import bytes_from_hex + +logger = logging.getLogger(__name__) + +DOMAIN_NAME = "localhost" + + +def is_token_still_valid(datestr: str) -> bool: + """ + Checks if a token has expired based on its expiry timestamp + """ + current_datetime = datetime.datetime.now(tz=datetime.timezone.utc) + expiry_datetime = datetime.datetime.fromisoformat(datestr.replace("Z", "+00:00")) + + return expiry_datetime > current_datetime + + +def verify_wallet_signature(signature: bytes, message: str, address: str) -> bool: + """ + Verifies a signature issued by a wallet + """ + enc_msg = encode_defunct(hexstr=message) + computed_address = Account.recover_message(enc_msg, signature=signature) + + return 
computed_address.lower() == address.lower() + + +class SignedPubKeyPayload(BaseModel): + """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" + + pubkey: Dict[str, Any] + # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', + # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} + # alg: Literal["ECDSA"] + address: str + expires: str + + @property + def json_web_key(self) -> jwk.JWK: + """Return the ephemeral public key as Json Web Key""" + + return jwk.JWK(**self.pubkey) + + +class SignedPubKeyHeader(BaseModel): + signature: bytes + payload: bytes + + @field_validator("signature") + def signature_must_be_hex(cls, value: bytes) -> bytes: + """Convert the signature from hexadecimal to bytes""" + return bytes_from_hex(value.decode()) + + @field_validator("payload") + def payload_must_be_hex(cls, value: bytes) -> bytes: + """Convert the payload from hexadecimal to bytes""" + return bytes_from_hex(value.decode()) + + @model_validator(mode="after") # type: ignore + def check_expiry(self) -> Self: + """Check that the token has not expired""" + payload: bytes = self.payload + content = SignedPubKeyPayload.model_validate_json(payload) + + if not is_token_still_valid(content.expires): + msg = "Token expired" + raise ValueError(msg) + + return self + + @model_validator(mode="after") # type: ignore + def check_signature(self) -> Self: + signature: bytes = self.signature + payload: bytes = self.payload + content = SignedPubKeyPayload.model_validate_json(payload) + + if not verify_wallet_signature(signature, payload.hex(), content.address): + msg = "Invalid signature" + raise ValueError(msg) + + return self + + @property + def content(self) -> SignedPubKeyPayload: + """Return the content of the header""" + return SignedPubKeyPayload.model_validate_json(self.payload) + + +class SignedOperationPayload(BaseModel): + time: 
datetime.datetime + method: Union[Literal["POST"], Literal["GET"]] + domain: str + path: str + # body_sha256: str # disabled since there is no body + + @field_validator("time") + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: + """Check that the time is current and the payload is not a replay attack.""" + max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( + minutes=2 + ) + max_future = datetime.datetime.now( + tz=datetime.timezone.utc + ) + datetime.timedelta(minutes=2) + if v < max_past: + raise ValueError("Time is too far in the past") + if v > max_future: + raise ValueError("Time is too far in the future") + return v + + +class SignedOperation(BaseModel): + """This payload is signed by the ephemeral key authorized above.""" + + signature: bytes + payload: bytes + + @field_validator("signature") + def signature_must_be_hex(cls, value: str) -> bytes: + """Convert the signature from hexadecimal to bytes""" + + try: + if isinstance(value, bytes): + value = value.decode() + return bytes_from_hex(value) + except pydantic.ValidationError as error: + logger.warning(value) + raise error + + @field_validator("payload") + def payload_must_be_hex(cls, v) -> bytes: + """Convert the payload from hexadecimal to bytes""" + v = bytes.fromhex(v.decode()) + _ = SignedOperationPayload.model_validate_json(v) + return v + + @property + def content(self) -> SignedOperationPayload: + """Return the content of the header""" + return SignedOperationPayload.model_validate_json(self.payload) + + +def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: + """Get the ephemeral public key that is signed by the wallet from the request headers.""" + signed_pubkey_header = request.headers.get("X-SignedPubKey") + + if not signed_pubkey_header: + raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header") + + try: + return SignedPubKeyHeader.model_validate_json(signed_pubkey_header) + + except KeyError as error: + 
logger.debug(f"Missing X-SignedPubKey header: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error + + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error + + except ValueError as errors: + logging.debug(errors) + + for err in errors.args[0]: + if isinstance(err.exc, json.JSONDecodeError): + raise web.HTTPBadRequest( + reason="Invalid X-SignedPubKey format" + ) from errors + + if str(err.exc) == "Token expired": + raise web.HTTPUnauthorized(reason="Token expired") from errors + + if str(err.exc) == "Invalid signature": + raise web.HTTPUnauthorized(reason="Invalid signature") from errors + else: + raise errors + + +def get_signed_operation(request: web.Request) -> SignedOperation: + """Get the signed operation public key that is signed by the ephemeral key from the request headers.""" + try: + signed_operation = request.headers["X-SignedOperation"] + return SignedOperation.model_validate_json(signed_operation) + except KeyError as error: + raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error + except json.JSONDecodeError as error: + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error + except ValidationError as error: + logger.debug(f"Invalid X-SignedOperation fields: {error}") + raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error + + +def verify_signed_operation( + signed_operation: SignedOperation, signed_pubkey: SignedPubKeyHeader +) -> str: + """Verify that the operation is signed by the ephemeral key authorized by the wallet.""" + pubkey = signed_pubkey.content.json_web_key + + try: + JWA.signing_alg("ES256").verify( + pubkey, signed_operation.payload, signed_operation.signature + ) + logger.debug("Signature verified") + + return signed_pubkey.content.address + + except cryptography.exceptions.InvalidSignature as e: + logger.debug("Failing to validate signature for operation", e) + + 
raise web.HTTPUnauthorized(reason="Signature could not verified") + + +async def authenticate_jwk( + request: web.Request, domain_name: Optional[str] = DOMAIN_NAME +) -> str: + """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" + signed_pubkey = get_signed_pubkey(request) + signed_operation = get_signed_operation(request) + + if signed_operation.content.domain != domain_name: + logger.debug( + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + + if signed_operation.content.path != request.path: + logger.debug( + f"Invalid path '{signed_operation.content.path}' != '{request.path}'" + ) + raise web.HTTPUnauthorized(reason="Invalid path") + if signed_operation.content.method != request.method: + logger.debug( + f"Invalid method '{signed_operation.content.method}' != '{request.method}'" + ) + raise web.HTTPUnauthorized(reason="Invalid method") + return verify_signed_operation(signed_operation, signed_pubkey) + + +async def authenticate_websocket_message( + message, domain_name: Optional[str] = DOMAIN_NAME +) -> str: + """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" + signed_pubkey = SignedPubKeyHeader.model_validate(message["X-SignedPubKey"]) + signed_operation = SignedOperation.model_validate(message["X-SignedOperation"]) + if signed_operation.content.domain != domain_name: + logger.debug( + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" + ) + raise web.HTTPUnauthorized(reason="Invalid domain") + return verify_signed_operation(signed_operation, signed_pubkey) + + +def require_jwk_authentication( + handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] +) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: + @functools.wraps(handler) + async def wrapper(request): + try: + authenticated_sender: str = await authenticate_jwk(request) + except web.HTTPException as 
e: + return web.json_response(data={"error": e.reason}, status=e.status) + except Exception as e: + # Unexpected make sure to log it + logging.exception(e) + raise + + # authenticated_sender is the authenticted wallet address of the requester (as a string) + response = await handler(request, authenticated_sender) + return response + + return wrapper diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7d388e36..5086703b 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,14 +1,18 @@ +import asyncio import json +from functools import wraps +from io import BytesIO from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional, Union from unittest.mock import AsyncMock, MagicMock import pytest as pytest +from aiohttp import ClientResponseError from aleph_message.models import AggregateMessage, AlephMessage, PostMessage import aleph.sdk.chains.ethereum as ethereum -import aleph.sdk.chains.sol as solana +import aleph.sdk.chains.solana as solana import aleph.sdk.chains.substrate as substrate import aleph.sdk.chains.tezos as tezos from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient @@ -67,7 +71,7 @@ def rejected_message(): @pytest.fixture def aleph_messages() -> List[AlephMessage]: return [ - AggregateMessage.parse_obj( + AggregateMessage.model_validate( { "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6", "type": "AGGREGATE", @@ -91,7 +95,7 @@ def aleph_messages() -> List[AlephMessage]: "confirmed": False, } ), - PostMessage.parse_obj( + PostMessage.model_validate( { "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", "type": "POST", @@ -131,7 +135,9 @@ def json_post() -> dict: def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]: return lambda page: { "messages": ( - [message.dict() for message in aleph_messages] if int(page) == 1 else [] 
+ [message.model_dump() for message in aleph_messages] + if int(page) == 1 + else [] ), "pagination_item": "messages", "pagination_page": int(page), @@ -160,7 +166,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): ... - async def raise_for_status(self): ... + def raise_for_status(self): ... @property def status(self): @@ -190,24 +196,64 @@ def mock_session_with_post_success( client = AuthenticatedAlephHttpClient( account=ethereum_account, api_server="http://localhost" ) - client.http_session = http_session + client._http_session = http_session return client -def make_custom_mock_response(resp_json, status=200) -> MockResponse: +def async_wrap(cls): + class AsyncWrapper: + def __init__(self, *args, **kwargs): + self._instance = cls(*args, **kwargs) + + def __getattr__(self, item): + attr = getattr(self._instance, item) + if callable(attr): + + @wraps(attr) + async def method(*args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, attr, *args, **kwargs) + + return method + return attr + + return AsyncWrapper + + +AsyncBytesIO = async_wrap(BytesIO) + + +def make_custom_mock_response( + resp: Union[Dict[str, Any], bytes], status=200 +) -> MockResponse: class CustomMockResponse(MockResponse): + content: Optional[AsyncBytesIO] + async def json(self): - return resp_json + return resp + + def raise_for_status(self): + if status >= 400: + raise ClientResponseError(None, None, status=status) @property def status(self): return status - return CustomMockResponse(sync=True) + mock = CustomMockResponse(sync=True) + + try: + mock.content = AsyncBytesIO(resp) + except Exception as e: + print(e) + + return mock -def make_mock_get_session(get_return_value: Dict[str, Any]) -> AlephHttpClient: +def make_mock_get_session( + get_return_value: Union[Dict[str, Any], bytes] +) -> AlephHttpClient: class MockHttpSession(AsyncMock): def get(self, *_args, **_kwargs): return make_custom_mock_response(get_return_value) 
@@ -215,7 +261,22 @@ def get(self, *_args, **_kwargs): http_session = MockHttpSession() client = AlephHttpClient(api_server="http://localhost") - client.http_session = http_session + client._http_session = http_session + + return client + + +def make_mock_get_session_400( + get_return_value: Union[Dict[str, Any], bytes] +) -> AlephHttpClient: + class MockHttpSession(AsyncMock): + def get(self, *_args, **_kwargs): + return make_custom_mock_response(get_return_value, 400) + + http_session = MockHttpSession() + + client = AlephHttpClient(api_server="http://localhost") + client._http_session = http_session return client @@ -242,6 +303,68 @@ def post(self, *_args, **_kwargs): client = AuthenticatedAlephHttpClient( account=ethereum_account, api_server="http://localhost" ) - client.http_session = http_session + client._http_session = http_session return client + + +@pytest.fixture +def make_mock_aiohttp_session(): + def _make(mocked_json_response): + mock_response = AsyncMock() + mock_response.json.return_value = mocked_json_response + mock_response.raise_for_status.return_value = None + + session = MagicMock() + + get_cm = AsyncMock() + get_cm.__aenter__.return_value = mock_response + session.get.return_value = get_cm + + session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + return session_cm + + return _make + + +# Constants needed for voucher tests +MOCK_ADDRESS = "0x1234567890123456789012345678901234567890" +MOCK_SOLANA_ADDRESS = "abcdefghijklmnopqrstuvwxyz123456789" +MOCK_METADATA_ID = "metadata123" +MOCK_VOUCHER_ID = "voucher123" +MOCK_METADATA = { + "name": "Test Voucher", + "description": "A test voucher", + "external_url": "https://example.com", + "image": "https://example.com/image.png", + "icon": "https://example.com/icon.png", + "attributes": [ + {"trait_type": "Test Trait", "value": "Test Value"}, + {"trait_type": "Numeric Trait", "value": "123", "display_type": "number"}, + ], +} + +MOCK_EVM_VOUCHER_DATA = [ + (MOCK_VOUCHER_ID, 
{"claimer": MOCK_ADDRESS, "metadata_id": MOCK_METADATA_ID}) +] + +MOCK_SOLANA_REGISTRY = { + "claimed_tickets": { + "solticket123": {"claimer": MOCK_SOLANA_ADDRESS, "batch_id": "batch123"} + }, + "batches": {"batch123": {"metadata_id": MOCK_METADATA_ID}}, +} + + +@pytest.fixture +def mock_post_response(): + mock_post = MagicMock() + mock_post.content = { + "nft_vouchers": { + MOCK_VOUCHER_ID: {"claimer": MOCK_ADDRESS, "metadata_id": MOCK_METADATA_ID} + } + } + posts_response = MagicMock() + posts_response.posts = [mock_post] + return posts_response diff --git a/tests/unit/services/__init__.py b/tests/unit/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/services/mocks.py b/tests/unit/services/mocks.py new file mode 100644 index 00000000..86f473b7 --- /dev/null +++ b/tests/unit/services/mocks.py @@ -0,0 +1,345 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ..conftest import make_custom_mock_response + +FAKE_CRN_GPU_HASH = "abcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabcabca" +FAKE_CRN_GPU_ADDRESS = "0xBCABCABCABCABCABCABCABCABCABCABCABCABCAB" +FAKE_CRN_GPU_URL = "https://test.gpu.crn.com" + +FAKE_CRN_CONF_HASH = "defdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefdefd" +FAKE_CRN_CONF_ADDRESS = "0xDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEfDEf" +FAKE_CRN_CONF_URL = "https://test.conf.crn" + +FAKE_CRN_BASIC_HASH = "aaaabbbbccccddddeeeeffff1111222233334444555566667777888899990000" +FAKE_CRN_BASIC_ADDRESS = "0xAAAABBBBCCCCDDDDEEEEFFFF1111222233334444" +FAKE_CRN_BASIC_URL = "https://test.basic.crn.com" + + +@pytest.fixture +def vm_status_v2(): + return { + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef": { + "networking": { + "ipv4_network": "192.168.0.0/24", + "host_ipv4": "192.168.0.1", + "ipv6_network": "2001:db8::/64", + "ipv6_ip": "2001:db8::1", + "mapped_ports": {}, + }, + "status": { + "defined_at": "2023-01-01T00:00:00Z", + "started_at": 
"2023-01-01T00:00:00Z", + "preparing_at": "2023-01-01T00:00:00Z", + "prepared_at": "2023-01-01T00:00:00Z", + "starting_at": "2023-01-01T00:00:00Z", + "stopping_at": "2023-01-01T00:00:00Z", + "stopped_at": "2023-01-01T00:00:00Z", + }, + "running": True, + } + } + + +@pytest.fixture +def vm_status_v1(): + return { + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef": { + "networking": {"ipv4": "192.168.0.1", "ipv6": "2001:db8::1"} + } + } + + +@pytest.fixture +def mock_crn_list(): + """Create a mock CRN list for testing.""" + return [ + { + "hash": FAKE_CRN_GPU_HASH, + "name": "Test GPU Instance", + "time": 1739525120.505, + "type": "compute", + "owner": FAKE_CRN_GPU_ADDRESS, + "score": 0.964502797686815, + "banner": "", + "locked": True, + "parent": FAKE_CRN_GPU_HASH, + "reward": FAKE_CRN_GPU_ADDRESS, + "status": "linked", + "address": FAKE_CRN_GPU_URL, + "manager": "", + "picture": "", + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": True, + "stream_reward": FAKE_CRN_GPU_ADDRESS, + "inactive_since": None, + "decentralization": 0.852680607762069, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": True, + "debug_config_from_crn_at": "2025-06-18T12:09:03.843059+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": "2025-06-18T12:09:03.843059+00:00", + "usage_from_crn_error": "None", + "version": "1.6.0-rc1", + "payment_receiver_address": FAKE_CRN_GPU_ADDRESS, + "gpu_support": True, + "confidential_support": False, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 20, + "load_average": { + "load1": 0.357421875, + "load5": 0.31982421875, + "load15": 0.34912109375, + }, + "core_frequencies": {"min": 800, "max": 4280}, + }, + "mem": {"total_kB": 67219530, "available_kB": 61972037}, + "disk": {"total_kB": 1853812338, "available_kB": 1320664518}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + 
"properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "GenuineIntel", + "features": [], + } + }, + "gpu": { + "devices": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "available_devices": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + }, + "active": True, + }, + "compatible_gpus": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "compatible_available_gpus": [ + { + "vendor": "NVIDIA", + "model": "RTX 4000 ADA", + "device_name": "AD104GL [RTX 4000 SFF Ada Generation]", + "device_class": "0300", + "pci_host": "01:00.0", + "device_id": "10de:27b0", + "compatible": True, + } + ], + "ipv6_check": {"host": True, "vm": True}, + }, + { + "hash": FAKE_CRN_CONF_HASH, + "name": "Test Conf CRN", + "time": 1739296606.021, + "type": "compute", + "owner": FAKE_CRN_CONF_ADDRESS, + "score": 0.964334395009276, + "banner": "", + "locked": False, + "parent": FAKE_CRN_CONF_HASH, + "reward": FAKE_CRN_CONF_ADDRESS, + "status": "linked", + "address": FAKE_CRN_CONF_URL, + "manager": "", + "picture": "", + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": False, + "stream_reward": FAKE_CRN_CONF_ADDRESS, + "inactive_since": None, + "decentralization": 0.994724704221032, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": False, + "debug_config_from_crn_at": "2025-06-18T12:09:03.951298+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": 
"2025-06-18T12:09:03.951298+00:00", + "usage_from_crn_error": "None", + "version": "1.5.1", + "payment_receiver_address": FAKE_CRN_CONF_ADDRESS, + "gpu_support": False, + "confidential_support": True, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 224, + "load_average": { + "load1": 3.8466796875, + "load5": 3.9228515625, + "load15": 3.82080078125, + }, + "core_frequencies": {"min": 1500, "max": 2200}, + }, + "mem": {"total_kB": 807728145, "available_kB": 630166945}, + "disk": {"total_kB": 14971880235, "available_kB": 152975388}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + "properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "AuthenticAMD", + "features": ["sev", "sev_es"], + } + }, + "gpu": {"devices": [], "available_devices": []}, + "active": True, + }, + "compatible_gpus": [], + "compatible_available_gpus": [], + "ipv6_check": {"host": True, "vm": True}, + }, + { + "hash": FAKE_CRN_BASIC_HASH, + "name": "Test Basic CRN", + "time": 1687179700.242, + "type": "compute", + "owner": FAKE_CRN_BASIC_ADDRESS, + "score": 0.979808976368904, + "banner": FAKE_CRN_BASIC_HASH, + "locked": False, + "parent": FAKE_CRN_BASIC_HASH, + "reward": FAKE_CRN_BASIC_ADDRESS, + "status": "linked", + "address": FAKE_CRN_BASIC_URL, + "manager": FAKE_CRN_BASIC_ADDRESS, + "picture": FAKE_CRN_BASIC_HASH, + "authorized": "", + "description": "", + "performance": 0, + "multiaddress": "", + "score_updated": True, + "stream_reward": FAKE_CRN_BASIC_ADDRESS, + "inactive_since": None, + "decentralization": 0.93953628188216, + "registration_url": "", + "terms_and_conditions": "", + "config_from_crn": True, + "debug_config_from_crn_at": "2025-06-18T12:08:59.599676+00:00", + "debug_config_from_crn_error": "None", + "debug_usage_from_crn_at": "2025-06-18T12:08:59.599676+00:00", + "usage_from_crn_error": "None", + "version": "1.5.1", + "payment_receiver_address": FAKE_CRN_BASIC_ADDRESS, + "gpu_support": False, + 
"confidential_support": False, + "qemu_support": True, + "system_usage": { + "cpu": { + "count": 32, + "load_average": {"load1": 0, "load5": 0.01513671875, "load15": 0}, + "core_frequencies": {"min": 1200, "max": 3400}, + }, + "mem": {"total_kB": 270358832, "available_kB": 266152607}, + "disk": {"total_kB": 1005067972, "available_kB": 919488466}, + "period": { + "start_timestamp": "2025-06-18T12:09:00Z", + "duration_seconds": 60, + }, + "properties": { + "cpu": { + "architecture": "x86_64", + "vendor": "GenuineIntel", + "features": [], + } + }, + "gpu": {"devices": [], "available_devices": []}, + "active": True, + }, + "compatible_gpus": [], + "compatible_available_gpus": [], + "ipv6_check": {"host": True, "vm": False}, + }, + ] + + +def make_mock_aiohttp_session(mocked_json_response): + mock_response = AsyncMock() + mock_response.json.return_value = mocked_json_response + mock_response.raise_for_status.return_value = None + + session = MagicMock() + + session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + + get_cm = AsyncMock() + get_cm.__aenter__.return_value = mock_response + + post_cm = AsyncMock() + post_cm.__aenter__.return_value = mock_response + + session.get = MagicMock(return_value=get_cm) + session.post = MagicMock(return_value=post_cm) + + return session_cm + + +def make_mock_get_active_vms_parametrized(v2_fails, expected_payload): + session = MagicMock() + + def get(url, *args, **kwargs): + mock_resp = None + if "/v2/about/executions/list" in url and v2_fails: + mock_resp = make_custom_mock_response(expected_payload, 404) + else: + mock_resp = make_custom_mock_response(expected_payload) + + mock_ctx = AsyncMock() + mock_ctx.__aenter__.return_value = mock_resp + return mock_ctx + + def post(url, *args, **kwargs): + if "/update" in url: + return make_custom_mock_response( + {"status": "ok", "msg": "VM not starting yet"}, 200 + ) + return None + + session.get = MagicMock(side_effect=get) + + session.post = MagicMock(side_effect=post) + 
+ session_cm = AsyncMock() + session_cm.__aenter__.return_value = session + + return session_cm diff --git a/tests/unit/services/pricing_aggregate.json b/tests/unit/services/pricing_aggregate.json new file mode 100644 index 00000000..70f747ef --- /dev/null +++ b/tests/unit/services/pricing_aggregate.json @@ -0,0 +1,286 @@ +{ + "address": "0xFba561a84A537fCaa567bb7A2257e7142701ae2A", + "data": { + "pricing": { + "program": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.011", + "holding": "200", + "credit": "0.011" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 2048, + "memory_mib": 2048 + } + }, + "storage": { + "price": { + "storage": { + "holding": "0.333333333", + "credit": "0.333333333" + } + } + }, + "instance": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.055", + "holding": "1000", + "credit": "0.055" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "web3_hosting": { + "price": { + "fixed": 50, + "storage": { + "holding": "0.333333333" + } + } + }, + "program_persistent": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.055", + 
"holding": "1000", + "credit": "0.055" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "instance_gpu_premium": { + "price": { + "storage": { + "payg": "0.000000977", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.56", + "credit": "0.56" + } + }, + "tiers": [ + { + "id": "tier-1", + "vram": 81920, + "model": "A100", + "compute_units": 16 + }, + { + "id": "tier-2", + "vram": 81920, + "model": "H100", + "compute_units": 24 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 61440, + "memory_mib": 6144 + } + }, + "instance_confidential": { + "price": { + "storage": { + "payg": "0.000000977", + "holding": "0.05", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.11", + "holding": "2000", + "credit": "0.11" + } + }, + "tiers": [ + { + "id": "tier-1", + "compute_units": 1 + }, + { + "id": "tier-2", + "compute_units": 2 + }, + { + "id": "tier-3", + "compute_units": 4 + }, + { + "id": "tier-4", + "compute_units": 6 + }, + { + "id": "tier-5", + "compute_units": 8 + }, + { + "id": "tier-6", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 20480, + "memory_mib": 2048 + } + }, + "instance_gpu_standard": { + "price": { + "storage": { + "payg": "0.000000977", + "credit": "0.000000977" + }, + "compute_unit": { + "payg": "0.28", + "credit": "0.28" + } + }, + "tiers": [ + { + "id": "tier-1", + "vram": 20480, + "model": "RTX 4000 ADA", + "compute_units": 3 + }, + { + "id": "tier-2", + "vram": 24576, + "model": "RTX 3090", + "compute_units": 4 + }, + { + "id": "tier-3", + "vram": 24576, + "model": "RTX 4090", + "compute_units": 6 + }, + { + "id": "tier-3", + "vram": 32768, 
+ "model": "RTX 5090", + "compute_units": 8 + }, + { + "id": "tier-4", + "vram": 49152, + "model": "L40S", + "compute_units": 12 + } + ], + "compute_unit": { + "vcpus": 1, + "disk_mib": 61440, + "memory_mib": 6144 + } + } + } + }, + "info": { + + } +} \ No newline at end of file diff --git a/tests/unit/services/test_authenticated_voucher.py b/tests/unit/services/test_authenticated_voucher.py new file mode 100644 index 00000000..bb83ea74 --- /dev/null +++ b/tests/unit/services/test_authenticated_voucher.py @@ -0,0 +1,111 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from aleph.sdk.client.services.authenticated_voucher import AuthenticatedVoucher + +from ..conftest import ( + MOCK_ADDRESS, + MOCK_METADATA, + MOCK_SOLANA_ADDRESS, + MOCK_SOLANA_REGISTRY, + MOCK_VOUCHER_ID, +) + + +def test_resolve_address_with_argument(): + client = MagicMock() + service = AuthenticatedVoucher(client=client) + assert service._resolve_address(address="custom-address") == "custom-address" + + +def test_resolve_address_with_account_fallback(): + mock_account = MagicMock() + mock_account.get_address.return_value = MOCK_ADDRESS + + client = MagicMock() + client.account = mock_account + + service = AuthenticatedVoucher(client=client) + assert service._resolve_address(address=None) == MOCK_ADDRESS + mock_account.get_address.assert_called_once() + + +def test_resolve_address_no_address_no_account(): + client = MagicMock() + client.account = None + + service = AuthenticatedVoucher(client=client) + + with pytest.raises( + ValueError, match="No address provided and no account configured" + ): + service._resolve_address(address=None) + + +@pytest.mark.asyncio +async def test_get_vouchers_fallback_to_account( + make_mock_aiohttp_session, mock_post_response +): + mock_account = MagicMock() + mock_account.get_address.return_value = MOCK_ADDRESS + + mock_client = MagicMock() + mock_client.account = mock_account + mock_client.get_posts = 
AsyncMock(return_value=mock_post_response) + + service = AuthenticatedVoucher(client=mock_client) + + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + with patch("aiohttp.ClientSession", return_value=metadata_session): + vouchers = await service.get_vouchers() + + assert len(vouchers) == 1 + assert vouchers[0].name == MOCK_METADATA["name"] + mock_account.get_address.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_evm_vouchers_fallback_to_account( + make_mock_aiohttp_session, mock_post_response +): + mock_account = MagicMock() + mock_account.get_address.return_value = MOCK_ADDRESS + + mock_client = MagicMock() + mock_client.account = mock_account + mock_client.get_posts = AsyncMock(return_value=mock_post_response) + + service = AuthenticatedVoucher(client=mock_client) + + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + with patch("aiohttp.ClientSession", return_value=metadata_session): + vouchers = await service.get_evm_vouchers() + + assert len(vouchers) == 1 + assert vouchers[0].id == MOCK_VOUCHER_ID + + +@pytest.mark.asyncio +async def test_get_solana_vouchers_fallback_to_account(make_mock_aiohttp_session): + mock_account = MagicMock() + mock_account.get_address.return_value = MOCK_SOLANA_ADDRESS + + mock_client = MagicMock() + mock_client.account = mock_account + + service = AuthenticatedVoucher(client=mock_client) + + registry_session = make_mock_aiohttp_session(MOCK_SOLANA_REGISTRY) + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + with patch( + "aiohttp.ClientSession", side_effect=[registry_session, metadata_session] + ): + vouchers = await service.get_solana_vouchers() + + assert len(vouchers) == 1 + assert vouchers[0].id == "solticket123" + assert vouchers[0].name == MOCK_METADATA["name"] diff --git a/tests/unit/services/test_authorizations.py b/tests/unit/services/test_authorizations.py new file mode 100644 index 00000000..7ab2b7ee --- /dev/null +++ 
b/tests/unit/services/test_authorizations.py @@ -0,0 +1,562 @@ +""" +Tests for authorization methods in AlephClient. +""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import pytest +from aleph_message.models import AggregateMessage, Chain, MessageType +from aleph_message.status import MessageStatus + +from aleph.sdk.client.abstract import AuthenticatedAlephClient +from aleph.sdk.types import ( + Account, + Authorization, + AuthorizationBuilder, + SecurityAggregateContent, +) + + +class FakeAccount: + """Minimal fake account for testing.""" + + CHAIN = "ETH" + CURVE = "secp256k1" + + def __init__(self, address: str = "0xTestAddress1234567890123456789012345678"): + self._address = address + + async def sign_message(self, message: Dict) -> Dict: + message["signature"] = "0x" + "ab" * 65 + return message + + async def sign_raw(self, buffer: bytes) -> bytes: + return b"fake_signature" + + def get_address(self) -> str: + return self._address + + def get_public_key(self) -> str: + return "0x" + "cd" * 33 + + +class MockAlephClient(AuthenticatedAlephClient): + """ + A fake authenticated client that maintains an in-memory aggregate store. + Aggregates are dictionaries that get merged/updated with each create_aggregate call. 
+ """ + + def __init__(self, account: Optional[Account] = None): + self.account = account or FakeAccount() + # Storage: {address: {key: content}} + self._aggregates: Dict[str, Dict[str, Any]] = {} + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Any]: + """Fetch a single aggregate by address and key.""" + if address not in self._aggregates: + return {"authorizations": []} + return self._aggregates[address].get(key, {"authorizations": []}) + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + """Fetch multiple aggregates.""" + if address not in self._aggregates: + return {} + if keys is None: + return self._aggregates[address] + return {k: v for k, v in self._aggregates[address].items() if k in keys} + + async def create_aggregate( + self, + key: str, + content: Dict[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create/update an aggregate. Merges content into existing aggregate. 
+ """ + address = address or self.account.get_address() + + if address not in self._aggregates: + self._aggregates[address] = {} + + # Aggregates merge content (like a dict update) + if key in self._aggregates[address]: + self._aggregates[address][key].update(content) + else: + self._aggregates[address][key] = content + + # Return a minimal mock message + mock_message = AggregateMessage.model_validate( + { + "item_hash": "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + "type": "AGGREGATE", + "chain": "ETH", + "sender": address, + "signature": "0x" + "ab" * 65, + "item_type": "inline", + "item_content": "{}", + "content": { + "key": key, + "address": address, + "content": content, + "time": 0, + }, + "time": 0, + "channel": channel or "TEST", + } + ) + return mock_message, MessageStatus.PROCESSED + + # Stub implementations for abstract methods we don't need + async def create_post(self, *args, **kwargs): + raise NotImplementedError + + async def create_store(self, *args, **kwargs): + raise NotImplementedError + + async def create_program(self, *args, **kwargs): + raise NotImplementedError + + async def create_instance(self, *args, **kwargs): + raise NotImplementedError + + async def forget(self, *args, **kwargs): + raise NotImplementedError + + async def submit(self, *args, **kwargs): + raise NotImplementedError + + async def get_posts(self, *args, **kwargs): + raise NotImplementedError + + async def download_file(self, *args, **kwargs): + raise NotImplementedError + + async def download_file_to_path(self, *args, **kwargs): + raise NotImplementedError + + async def get_messages(self, *args, **kwargs): + raise NotImplementedError + + async def get_message(self, *args, **kwargs): + raise NotImplementedError + + def watch_messages(self, *args, **kwargs): + raise NotImplementedError + + def get_estimated_price(self, *args, **kwargs): + raise NotImplementedError + + def get_program_price(self, *args, **kwargs): + raise NotImplementedError + + +# 
Fixtures +@pytest.fixture +def mock_client() -> MockAlephClient: + """Create a fresh fake client for each test.""" + return MockAlephClient() + + +@pytest.fixture +def mock_client_with_existing_auth() -> MockAlephClient: + """Create a fake client with pre-existing authorizations.""" + client = MockAlephClient() + client._aggregates[client.account.get_address()] = { + "security": { + "authorizations": [ + { + "address": "0xExistingAddress123456789012345678901234", + "chain": "ETH", + "channels": ["existing_channel"], + "types": ["POST"], + "post_types": [], + "aggregate_keys": [], + } + ] + } + } + return client + + +# Tests for get_authorizations +class TestGetAuthorizations: + @pytest.mark.asyncio + async def test_get_authorizations_empty(self, mock_client: MockAlephClient): + """When no authorizations exist, returns empty list.""" + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_get_authorizations_returns_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Returns existing authorizations from aggregate store.""" + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + + assert len(authorizations) == 1 + assert authorizations[0].address == "0xExistingAddress123456789012345678901234" + assert authorizations[0].chain == Chain.ETH + assert authorizations[0].channels == ["existing_channel"] + + +# Tests for update_all_authorizations +class TestUpdateAllAuthorizations: + @pytest.mark.asyncio + async def test_update_replaces_all_authorizations( + self, mock_client: MockAlephClient + ): + """update_all_authorizations replaces the entire authorization list.""" + auth1 = Authorization(address="0xAddress1111111111111111111111111111111111") + auth2 = Authorization(address="0xAddress2222222222222222222222222222222222") + + await 
mock_client.update_all_authorizations([auth1, auth2]) + + # Verify stored content + stored = mock_client._aggregates[mock_client.account.get_address()]["security"] + assert len(stored["authorizations"]) == 2 + + @pytest.mark.asyncio + async def test_update_with_empty_list_clears_authorizations( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Passing an empty list removes all authorizations.""" + await mock_client_with_existing_auth.update_all_authorizations([]) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_update_preserves_authorization_fields( + self, mock_client: MockAlephClient + ): + """All authorization fields are preserved when storing.""" + auth = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + await mock_client.update_all_authorizations([auth]) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + retrieved = authorizations[0] + assert retrieved.address == auth.address + assert retrieved.chain == Chain.ETH + assert retrieved.channels == ["channel1", "channel2"] + assert MessageType.post in retrieved.types + assert "blog" in retrieved.post_types + + +# Tests for add_authorization +class TestAddAuthorization: + @pytest.mark.asyncio + async def test_add_to_empty(self, mock_client: MockAlephClient): + """Adding authorization when none exist.""" + auth = Authorization(address="0xNewAddress1111111111111111111111111111111") + + await mock_client.add_authorization(auth) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + 
assert ( + authorizations[0].address == "0xNewAddress1111111111111111111111111111111" + ) + + @pytest.mark.asyncio + async def test_add_appends_to_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Adding authorization appends to existing list.""" + new_auth = Authorization( + address="0xNewAddress2222222222222222222222222222222", + channels=["new_channel"], + ) + + await mock_client_with_existing_auth.add_authorization(new_auth) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 2 + addresses = [a.address for a in authorizations] + assert "0xExistingAddress123456789012345678901234" in addresses + assert "0xNewAddress2222222222222222222222222222222" in addresses + + @pytest.mark.asyncio + async def test_add_multiple_authorizations_sequentially( + self, mock_client: MockAlephClient + ): + """Adding multiple authorizations one by one.""" + auth1 = Authorization(address="0xFirst11111111111111111111111111111111111") + auth2 = Authorization(address="0xSecond2222222222222222222222222222222222") + auth3 = Authorization(address="0xThird33333333333333333333333333333333333") + + await mock_client.add_authorization(auth1) + await mock_client.add_authorization(auth2) + await mock_client.add_authorization(auth3) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + + +# Tests for revoke_all_authorizations +class TestRevokeAllAuthorizations: + @pytest.mark.asyncio + async def test_revoke_removes_matching_address( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Revoking removes all authorizations for the specified address.""" + await mock_client_with_existing_auth.revoke_all_authorizations( + "0xExistingAddress123456789012345678901234" + ) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + 
mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_revoke_keeps_other_addresses(self, mock_client: MockAlephClient): + """Revoking only removes authorizations for the specified address.""" + auth1 = Authorization(address="0xToRevoke111111111111111111111111111111111") + auth2 = Authorization(address="0xToKeep22222222222222222222222222222222222") + auth3 = Authorization( + address="0xToRevoke111111111111111111111111111111111" + ) # Duplicate + + await mock_client.update_all_authorizations([auth1, auth2, auth3]) + + await mock_client.revoke_all_authorizations( + "0xToRevoke111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert ( + authorizations[0].address == "0xToKeep22222222222222222222222222222222222" + ) + + @pytest.mark.asyncio + async def test_revoke_nonexistent_address_is_noop( + self, mock_client: MockAlephClient + ): + """Revoking an address that doesn't exist does nothing.""" + auth = Authorization(address="0xExisting1111111111111111111111111111111111") + await mock_client.add_authorization(auth) + + await mock_client.revoke_all_authorizations( + "0xNonExistent22222222222222222222222222222" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + + @pytest.mark.asyncio + async def test_revoke_from_empty_is_noop(self, mock_client: MockAlephClient): + """Revoking when no authorizations exist doesn't error.""" + await mock_client.revoke_all_authorizations( + "0xAnyAddress111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + +# Integration tests - full workflows +class TestAuthorizationWorkflows: + @pytest.mark.asyncio + async def test_full_lifecycle(self, 
mock_client: MockAlephClient): + """Test complete authorization lifecycle: add, verify, revoke.""" + delegate_address = "0xDelegate111111111111111111111111111111111" + + # Initially empty + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + # Add authorization + auth = Authorization( + address=delegate_address, + channels=["MY_APP"], + types=[MessageType.post], + ) + await mock_client.add_authorization(auth) + + # Verify it exists + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate_address + assert "MY_APP" in authorizations[0].channels + + # Revoke + await mock_client.revoke_all_authorizations(delegate_address) + + # Verify it's gone + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_multiple_delegates_workflow(self, mock_client: MockAlephClient): + """Test managing authorizations for multiple delegate addresses.""" + delegate1 = "0xDelegate1111111111111111111111111111111111" + delegate2 = "0xDelegate2222222222222222222222222222222222" + + # Add two delegates + await mock_client.add_authorization( + Authorization(address=delegate1, channels=["channel_a"]) + ) + await mock_client.add_authorization( + Authorization(address=delegate2, channels=["channel_b"]) + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 2 + + # Revoke first delegate + await mock_client.revoke_all_authorizations(delegate1) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate2 + + @pytest.mark.asyncio + async def test_replace_all_authorizations(self, mock_client: 
MockAlephClient): + """Test replacing all authorizations at once.""" + # Add initial authorizations + await mock_client.add_authorization( + Authorization(address="0xOld111111111111111111111111111111111111111") + ) + await mock_client.add_authorization( + Authorization(address="0xOld222222222222222222222222222222222222222") + ) + + # Replace with new set + new_auths = [ + Authorization(address="0xNew111111111111111111111111111111111111111"), + Authorization(address="0xNew222222222222222222222222222222222222222"), + Authorization(address="0xNew333333333333333333333333333333333333333"), + ] + await mock_client.update_all_authorizations(new_auths) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + addresses = {a.address for a in authorizations} + assert "0xOld111111111111111111111111111111111111111" not in addresses + assert "0xNew111111111111111111111111111111111111111" in addresses + + +# Model tests +class TestAuthorizationModel: + def test_minimal_authorization(self): + """Authorization can be created with just an address.""" + auth = Authorization(address="0x1234567890123456789012345678901234567890") + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + + def test_full_authorization(self): + """Authorization with all fields set.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + chain=Chain.ETH, + channels=["ch1", "ch2"], + types=[MessageType.post, MessageType.store], + post_types=["blog"], + aggregate_keys=["settings"], + ) + assert auth.chain == Chain.ETH + assert len(auth.channels) == 2 + assert len(auth.types) == 2 + + def test_security_aggregate_serialization(self): + """SecurityAggregateContent serializes correctly.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + channels=["test"], + ) + content = 
SecurityAggregateContent(authorizations=[auth]) + dumped = content.model_dump() + + assert "authorizations" in dumped + assert len(dumped["authorizations"]) == 1 + assert dumped["authorizations"][0]["address"] == auth.address + + +class TestAuthorizationBuilder: + def test_authorization_builder_only_address(self): + """Test the AuthorizationBuilder.""" + auth = AuthorizationBuilder( + address="0x1234567890123456789012345678901234567890" + ).build() + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + assert auth.post_types == [] + assert auth.aggregate_keys == [] + + def test_authorization_builder(self): + """Test the AuthorizationBuilder with a detailed configuration.""" + sample_authorization = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + auth = AuthorizationBuilder(address=sample_authorization.address).chain( + sample_authorization.chain + ) + for channel in sample_authorization.channels: + auth = auth.channel(channel) + for message_type in sample_authorization.types: + auth = auth.message_type(message_type) + for post_type in sample_authorization.post_types: + auth = auth.post_type(post_type) + for aggregate_key in sample_authorization.aggregate_keys: + auth = auth.aggregate_key(aggregate_key) + auth = auth.build() + + assert auth == sample_authorization diff --git a/tests/unit/services/test_base_service.py b/tests/unit/services/test_base_service.py new file mode 100644 index 00000000..2c1304ec --- /dev/null +++ b/tests/unit/services/test_base_service.py @@ -0,0 +1,46 @@ +from typing import Optional +from unittest.mock import AsyncMock + +import pytest +from pydantic import BaseModel + +from aleph.sdk.client.services.base import AggregateConfig, BaseService 
+ + +class DummyModel(BaseModel): + foo: str + bar: Optional[int] + + +class DummyService(BaseService[DummyModel]): + aggregate_key = "dummy_key" + model_cls = DummyModel + + +@pytest.mark.asyncio +async def test_get_config_with_data(): + mock_client = AsyncMock() + mock_data = {"foo": "hello", "bar": 123} + mock_client.get_aggregate.return_value = mock_data + + service = DummyService(mock_client) + + result = await service.get_config("0xSOME_ADDRESS") + + assert isinstance(result, AggregateConfig) + assert result.data is not None + assert isinstance(result.data[0], DummyModel) + assert result.data[0].foo == "hello" + assert result.data[0].bar == 123 + + +@pytest.mark.asyncio +async def test_get_config_with_no_data(): + mock_client = AsyncMock() + mock_client.get_aggregate.return_value = None + + service = DummyService(mock_client) + result = await service.get_config("0xSOME_ADDRESS") + + assert isinstance(result, AggregateConfig) + assert result.data is None diff --git a/tests/unit/services/test_pricing.py b/tests/unit/services/test_pricing.py new file mode 100644 index 00000000..ab6f7981 --- /dev/null +++ b/tests/unit/services/test_pricing.py @@ -0,0 +1,212 @@ +import json +from decimal import Decimal +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aleph.sdk.client.http import AlephHttpClient +from aleph.sdk.client.services.pricing import ( + PAYG_GROUP, + PRICING_GROUPS, + GroupEntity, + Price, + Pricing, + PricingEntity, + PricingModel, + PricingPerEntity, +) + + +@pytest.fixture +def pricing_aggregate(): + """Load the pricing aggregate JSON file for testing.""" + json_path = Path(__file__).parent / "pricing_aggregate.json" + with open(json_path, "r") as f: + data = json.load(f) + return data + + +@pytest.fixture +def mock_client(pricing_aggregate): + """Create a real client with mocked HTTP responses.""" + # Create a mock response for the http session get method + mock_response = AsyncMock() + 
mock_response.raise_for_status.return_value = None + mock_response.json.return_value = pricing_aggregate + + # Create an async context manager for the mock response + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_response + + # Create a mock HTTP session + mock_session = AsyncMock() + mock_session.get = MagicMock(return_value=mock_context) + + client = AlephHttpClient(api_server="http://localhost") + client._http_session = mock_session + + return client + + +@pytest.mark.asyncio +async def test_get_pricing_aggregate(mock_client): + """Test fetching the pricing aggregate data.""" + pricing_service = Pricing(mock_client) + result = await pricing_service.get_pricing_aggregate() + + # Check the result is a PricingModel + assert isinstance(result, PricingModel) + + assert PricingEntity.STORAGE in result + assert PricingEntity.PROGRAM in result + assert PricingEntity.INSTANCE in result + + storage_entity = result[PricingEntity.STORAGE] + assert isinstance(storage_entity, PricingPerEntity) + assert "storage" in storage_entity.price + storage_price = storage_entity.price["storage"] + assert isinstance(storage_price, Price) # Add type assertion for mypy + assert storage_price.holding == Decimal("0.333333333") + + # Check program entity has correct compute unit details + program_entity = result[PricingEntity.PROGRAM] + assert isinstance(program_entity, PricingPerEntity) + assert program_entity.compute_unit is not None # Ensure compute_unit is not None + assert program_entity.compute_unit.vcpus == 1 + assert program_entity.compute_unit.memory_mib == 2048 + assert program_entity.compute_unit.disk_mib == 2048 + + # Check tiers in instance entity + instance_entity = result[PricingEntity.INSTANCE] + assert instance_entity.tiers is not None # Ensure tiers is not None + assert len(instance_entity.tiers) == 6 + assert instance_entity.tiers[0].id == "tier-1" + assert instance_entity.tiers[0].compute_units == 1 + + +@pytest.mark.asyncio +async def 
test_get_pricing_for_services(mock_client): + """Test fetching pricing for specific services.""" + pricing_service = Pricing(mock_client) + + # Test Case 1: Get pricing for storage and program services + services = [PricingEntity.STORAGE, PricingEntity.PROGRAM] + result = await pricing_service.get_pricing_for_services(services) + + # Check the result contains only the requested entities + assert len(result) == 2 + assert PricingEntity.STORAGE in result + assert PricingEntity.PROGRAM in result + assert PricingEntity.INSTANCE not in result + + # Verify specific pricing data + storage_price = result[PricingEntity.STORAGE].price["storage"] + assert isinstance(storage_price, Price) # Ensure it's a Price object + assert storage_price.holding == Decimal("0.333333333") + + compute_price = result[PricingEntity.PROGRAM].price["compute_unit"] + assert isinstance(compute_price, Price) # Ensure it's a Price object + assert compute_price.payg == Decimal("0.011") + assert compute_price.holding == Decimal("200") + + # Test Case 2: Using pre-fetched pricing aggregate + pricing_info = await pricing_service.get_pricing_aggregate() + result2 = await pricing_service.get_pricing_for_services(services, pricing_info) + + # Results should be the same + assert result[PricingEntity.STORAGE].price == result2[PricingEntity.STORAGE].price + assert result[PricingEntity.PROGRAM].price == result2[PricingEntity.PROGRAM].price + + # Test Case 3: Empty services list + empty_result = await pricing_service.get_pricing_for_services([]) + assert isinstance(empty_result, dict) + assert len(empty_result) == 0 + + # Test Case 4: Web3 hosting service + web3_result = await pricing_service.get_pricing_for_services( + [PricingEntity.WEB3_HOSTING] + ) + assert len(web3_result) == 1 + assert PricingEntity.WEB3_HOSTING in web3_result + assert web3_result[PricingEntity.WEB3_HOSTING].price["fixed"] == Decimal("50") + + # Test Case 5: GPU services have specific properties + gpu_services = [ + 
PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ] + gpu_result = await pricing_service.get_pricing_for_services(gpu_services) + assert len(gpu_result) == 2 + # Check GPU models are present + standard_tiers = gpu_result[PricingEntity.INSTANCE_GPU_STANDARD].tiers + premium_tiers = gpu_result[PricingEntity.INSTANCE_GPU_PREMIUM].tiers + assert standard_tiers is not None + assert premium_tiers is not None + assert standard_tiers[0].model == "RTX 4000 ADA" + assert premium_tiers[1].model == "H100" + + +@pytest.mark.asyncio +async def test_get_pricing_for_gpu_services(mock_client): + """Test fetching pricing for GPU services.""" + pricing_service = Pricing(mock_client) + + # Test with GPU services + gpu_services = [ + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, + ] + result = await pricing_service.get_pricing_for_services(gpu_services) + + # Check that both GPU services are returned + assert len(result) == 2 + assert PricingEntity.INSTANCE_GPU_STANDARD in result + assert PricingEntity.INSTANCE_GPU_PREMIUM in result + + # Verify GPU standard pricing and details + gpu_standard = result[PricingEntity.INSTANCE_GPU_STANDARD] + compute_unit_price = gpu_standard.price["compute_unit"] + assert isinstance(compute_unit_price, Price) + assert compute_unit_price.payg == Decimal("0.28") + + standard_tiers = gpu_standard.tiers + assert standard_tiers is not None + assert len(standard_tiers) == 5 + assert standard_tiers[0].model == "RTX 4000 ADA" + assert standard_tiers[0].vram == 20480 + + # Verify GPU premium pricing and details + gpu_premium = result[PricingEntity.INSTANCE_GPU_PREMIUM] + premium_compute_price = gpu_premium.price["compute_unit"] + assert isinstance(premium_compute_price, Price) + assert premium_compute_price.payg == Decimal("0.56") + + premium_tiers = gpu_premium.tiers + assert premium_tiers is not None + assert len(premium_tiers) == 2 + assert premium_tiers[1].model == "H100" + assert premium_tiers[1].vram 
== 81920 + + +@pytest.mark.asyncio +async def test_pricing_groups(): + """Test the pricing groups constants.""" + # Check that all pricing entities are covered in PRICING_GROUPS + all_entities = set() + for group_entities in PRICING_GROUPS.values(): + for entity in group_entities: + all_entities.add(entity) + + # All PricingEntity values should be in some group + for entity in PricingEntity: + assert entity in all_entities + + # Check ALL group contains all entities + assert set(PRICING_GROUPS[GroupEntity.ALL]) == set(PricingEntity) + + # Check PAYG_GROUP contains expected entities + assert PricingEntity.INSTANCE in PAYG_GROUP + assert PricingEntity.INSTANCE_CONFIDENTIAL in PAYG_GROUP + assert PricingEntity.INSTANCE_GPU_STANDARD in PAYG_GROUP + assert PricingEntity.INSTANCE_GPU_PREMIUM in PAYG_GROUP diff --git a/tests/unit/services/test_settings.py b/tests/unit/services/test_settings.py new file mode 100644 index 00000000..9ac09d9e --- /dev/null +++ b/tests/unit/services/test_settings.py @@ -0,0 +1,200 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aleph.sdk import AlephHttpClient +from aleph.sdk.client.services.settings import NetworkSettingsModel, Settings + + +@pytest.fixture +def mock_settings_aggregate_response(): + return { + "compatible_gpus": [ + { + "name": "AD102GL [L40S]", + "model": "L40S", + "vendor": "NVIDIA", + "device_id": "10de:26b9", + }, + { + "name": "GB202 [GeForce RTX 5090]", + "model": "RTX 5090", + "vendor": "NVIDIA", + "device_id": "10de:2685", + }, + { + "name": "GB202 [GeForce RTX 5090 D]", + "model": "RTX 5090", + "vendor": "NVIDIA", + "device_id": "10de:2687", + }, + { + "name": "AD102 [GeForce RTX 4090]", + "model": "RTX 4090", + "vendor": "NVIDIA", + "device_id": "10de:2684", + }, + { + "name": "AD102 [GeForce RTX 4090 D]", + "model": "RTX 4090", + "vendor": "NVIDIA", + "device_id": "10de:2685", + }, + { + "name": "GA102 [GeForce RTX 3090]", + "model": "RTX 3090", + "vendor": "NVIDIA", + "device_id": 
"10de:2204", + }, + { + "name": "GA102 [GeForce RTX 3090 Ti]", + "model": "RTX 3090", + "vendor": "NVIDIA", + "device_id": "10de:2203", + }, + { + "name": "AD104GL [RTX 4000 SFF Ada Generation]", + "model": "RTX 4000 ADA", + "vendor": "NVIDIA", + "device_id": "10de:27b0", + }, + { + "name": "AD104GL [RTX 4000 Ada Generation]", + "model": "RTX 4000 ADA", + "vendor": "NVIDIA", + "device_id": "10de:27b2", + }, + { + "name": "GA102GL [RTX A5000]", + "model": "RTX A5000", + "vendor": "NVIDIA", + "device_id": "10de:2231", + }, + { + "name": "GA102GL [RTX A6000]", + "model": "RTX A6000", + "vendor": "NVIDIA", + "device_id": "10de:2230", + }, + { + "name": "GH100 [H100]", + "model": "H100", + "vendor": "NVIDIA", + "device_id": "10de:2336", + }, + { + "name": "GH100 [H100 NVSwitch]", + "model": "H100", + "vendor": "NVIDIA", + "device_id": "10de:22a3", + }, + { + "name": "GH100 [H100 CNX]", + "model": "H100", + "vendor": "NVIDIA", + "device_id": "10de:2313", + }, + { + "name": "GH100 [H100 SXM5 80GB]", + "model": "H100", + "vendor": "NVIDIA", + "device_id": "10de:2330", + }, + { + "name": "GH100 [H100 PCIe]", + "model": "H100", + "vendor": "NVIDIA", + "device_id": "10de:2331", + }, + { + "name": "GA100", + "model": "A100", + "vendor": "NVIDIA", + "device_id": "10de:2080", + }, + { + "name": "GA100", + "model": "A100", + "vendor": "NVIDIA", + "device_id": "10de:2081", + }, + { + "name": "GA100 [A100 SXM4 80GB]", + "model": "A100", + "vendor": "NVIDIA", + "device_id": "10de:20b2", + }, + { + "name": "GA100 [A100 PCIe 80GB]", + "model": "A100", + "vendor": "NVIDIA", + "device_id": "10de:20b5", + }, + { + "name": "GA100 [A100X]", + "model": "A100", + "vendor": "NVIDIA", + "device_id": "10de:20b8", + }, + { + "name": "GH100 [H200 SXM 141GB]", + "model": "H200", + "vendor": "NVIDIA", + "device_id": "10de:2335", + }, + { + "name": "GH100 [H200 NVL]", + "model": "H200", + "vendor": "NVIDIA", + "device_id": "10de:233b", + }, + { + "name": "AD102GL [RTX 6000 ADA]", + "model": "RTX 
6000 ADA", + "vendor": "NVIDIA", + "device_id": "10de:26b1", + }, + ], + "last_crn_version": "1.7.2", + "community_wallet_address": "0x5aBd3258C5492fD378EBC2e0017416E199e5Da56", + "community_wallet_timestamp": 1739996239, + } + + +@pytest.mark.asyncio +async def test_get_settings_aggregate( + make_mock_aiohttp_session, mock_settings_aggregate_response +): + client = AlephHttpClient(api_server="http://localhost") + + # Properly mock the fetch_aggregate method using monkeypatch + client._http_session = MagicMock() + monkeypatch = AsyncMock(return_value=mock_settings_aggregate_response) + setattr(client, "get_aggregate", monkeypatch) + + settings_service = Settings(client) + result = await settings_service.get_settings_aggregate() + + assert isinstance(result, NetworkSettingsModel) + assert len(result.compatible_gpus) == 24 # We have 24 GPUs in the mock data + + rtx4000_gpu = next( + gpu for gpu in result.compatible_gpus if gpu.device_id == "10de:27b0" + ) + assert rtx4000_gpu.name == "AD104GL [RTX 4000 SFF Ada Generation]" + assert rtx4000_gpu.model == "RTX 4000 ADA" + assert rtx4000_gpu.vendor == "NVIDIA" + + assert result.last_crn_version == "1.7.2" + assert ( + result.community_wallet_address == "0x5aBd3258C5492fD378EBC2e0017416E199e5Da56" + ) + assert result.community_wallet_timestamp == 1739996239 + + # Verify that fetch_aggregate was called with the correct parameters + assert monkeypatch.call_count == 1 + assert ( + monkeypatch.call_args.kwargs["address"] + == "0xFba561a84A537fCaa567bb7A2257e7142701ae2A" + ) + assert monkeypatch.call_args.kwargs["key"] == "settings" diff --git a/tests/unit/services/test_voucher.py b/tests/unit/services/test_voucher.py new file mode 100644 index 00000000..7519ad19 --- /dev/null +++ b/tests/unit/services/test_voucher.py @@ -0,0 +1,120 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aleph_message.models import Chain + +from aleph.sdk.client.http import AlephHttpClient +from 
aleph.sdk.client.services.voucher import Vouchers + +from ..conftest import ( + MOCK_ADDRESS, + MOCK_METADATA, + MOCK_SOLANA_ADDRESS, + MOCK_SOLANA_REGISTRY, + MOCK_VOUCHER_ID, +) + + +@pytest.mark.asyncio +async def test_get_evm_vouchers(mock_post_response, make_mock_aiohttp_session): + client = AlephHttpClient(api_server="http://localhost") + + # Patch only the get_posts who is used to fetch voucher update for EVM + with patch.object(client, "get_posts", AsyncMock(return_value=mock_post_response)): + voucher_service = Vouchers(client=client) + + session = make_mock_aiohttp_session(MOCK_METADATA) + + # Here we patch the client sessions who gonna fetch the metdata of the NFT + with patch("aiohttp.ClientSession", return_value=session): + vouchers = await voucher_service.get_evm_vouchers(MOCK_ADDRESS) + + assert len(vouchers) == 1 + assert vouchers[0].id == MOCK_VOUCHER_ID + assert vouchers[0].name == MOCK_METADATA["name"] + + +@pytest.mark.asyncio +async def test_get_solana_vouchers(make_mock_aiohttp_session): + client = AlephHttpClient(api_server="http://localhost") + voucher_service = Vouchers(client=client) + + registry_session = make_mock_aiohttp_session(MOCK_SOLANA_REGISTRY) + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + # Here we patch the fetch of the registry made on + # https://api.claim.twentysix.cloud/v1/registry/solanna + # and we also patch the fetch of the metadata + # https://claim.twentysix.cloud/sbt/metadata/{}.json + with patch( + "aiohttp.ClientSession", side_effect=[registry_session, metadata_session] + ): + vouchers = await voucher_service.get_solana_vouchers(MOCK_SOLANA_ADDRESS) + + assert len(vouchers) == 1 + assert vouchers[0].id == "solticket123" + assert vouchers[0].name == MOCK_METADATA["name"] + + +@pytest.mark.asyncio +async def test_fetch_vouchers_by_chain_for_evm( + mock_post_response, make_mock_aiohttp_session +): + client = AlephHttpClient(api_server="http://localhost") + with patch.object(client, "get_posts", 
AsyncMock(return_value=mock_post_response)): + voucher_service = Vouchers(client=client) + + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + with patch("aiohttp.ClientSession", return_value=metadata_session): + vouchers = await voucher_service.fetch_vouchers_by_chain( + Chain.ETH, MOCK_ADDRESS + ) + + assert len(vouchers) == 1 + assert vouchers[0].id == "voucher123" + + +@pytest.mark.asyncio +async def test_fetch_vouchers_by_chain_for_solana(make_mock_aiohttp_session): + mock_client = MagicMock() + voucher_service = Vouchers(client=mock_client) + + registry_session = make_mock_aiohttp_session(MOCK_SOLANA_REGISTRY) + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + with patch( + "aiohttp.ClientSession", side_effect=[registry_session, metadata_session] + ): + vouchers = await voucher_service.fetch_vouchers_by_chain( + Chain.SOL, MOCK_SOLANA_ADDRESS + ) + + assert len(vouchers) == 1 + assert vouchers[0].id == "solticket123" + + +@pytest.mark.asyncio +async def test_get_vouchers_detects_chain( + make_mock_aiohttp_session, mock_post_response +): + client = AlephHttpClient(api_server="http://localhost") + with patch.object(client, "get_posts", AsyncMock(return_value=mock_post_response)): + voucher_service = Vouchers(client=client) + + # EVM + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + with patch("aiohttp.ClientSession", return_value=metadata_session): + vouchers = await voucher_service.get_vouchers(MOCK_ADDRESS) + assert len(vouchers) == 1 + assert vouchers[0].id == "voucher123" + + # Solana + registry_session = make_mock_aiohttp_session(MOCK_SOLANA_REGISTRY) + metadata_session = make_mock_aiohttp_session(MOCK_METADATA) + + with patch( + "aiohttp.ClientSession", side_effect=[registry_session, metadata_session] + ): + vouchers = await voucher_service.get_vouchers(MOCK_SOLANA_ADDRESS) + assert len(vouchers) == 1 + assert vouchers[0].id == "solticket123" diff --git a/tests/unit/test_asynchronous.py 
b/tests/unit/test_asynchronous.py index ef8b67ca..1221a9b0 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -7,6 +7,7 @@ Chain, ForgetMessage, InstanceMessage, + ItemHash, MessageType, Payment, PaymentType, @@ -14,7 +15,13 @@ ProgramMessage, StoreMessage, ) -from aleph_message.models.execution.environment import MachineResources +from aleph_message.models.execution.environment import ( + HostRequirements, + HypervisorType, + MachineResources, + NodeRequirements, + TrustedExecutionEnvironment, +) from aleph_message.status import MessageStatus from aleph.sdk.exceptions import InsufficientFundsError @@ -33,7 +40,7 @@ async def test_create_post(mock_session_with_post_success): sync=False, ) - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(post_message, PostMessage) assert message_status == MessageStatus.PENDING @@ -47,7 +54,7 @@ async def test_create_aggregate(mock_session_with_post_success): channel="TEST", ) - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(aggregate_message, AggregateMessage) @@ -83,7 +90,7 @@ async def test_create_store(mock_session_with_post_success): storage_engine=StorageEnum.storage, ) - assert mock_session_with_post_success.http_session.post.called + assert mock_session_with_post_success.http_session.post.assert_called assert isinstance(store_message, StoreMessage) @@ -98,7 +105,7 @@ async def test_create_program(mock_session_with_post_success): metadata={"tags": ["test"]}, ) - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(program_message, ProgramMessage) @@ -108,7 +115,6 @@ async def test_create_instance(mock_session_with_post_success): instance_message, 
message_status = await session.create_instance( rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", rootfs_size=1, - rootfs_name="rootfs", channel="TEST", metadata={"tags": ["test"]}, payment=Payment( @@ -116,9 +122,10 @@ async def test_create_instance(mock_session_with_post_success): receiver="0x4145f182EF2F06b45E50468519C1B92C60FBd4A0", type=PaymentType.superfluid, ), + hypervisor=HypervisorType.qemu, ) - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(instance_message, InstanceMessage) @@ -131,7 +138,6 @@ async def test_create_instance_no_payment(mock_session_with_post_success): instance_message, message_status = await session.create_instance( rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", rootfs_size=1, - rootfs_name="rootfs", channel="TEST", metadata={"tags": ["test"]}, payment=None, @@ -140,10 +146,63 @@ async def test_create_instance_no_payment(mock_session_with_post_success): assert instance_message.content.payment.type == PaymentType.hold assert instance_message.content.payment.chain == Chain.ETH - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(instance_message, InstanceMessage) +@pytest.mark.asyncio +async def test_create_instance_no_hypervisor(mock_session_with_post_success): + """Test that an instance can be created with no hypervisor specified. + It should in this case default to "firecracker". 
+ """ + async with mock_session_with_post_success as session: + instance_message, message_status = await session.create_instance( + rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + rootfs_size=1, + channel="TEST", + metadata={"tags": ["test"]}, + hypervisor=None, + ) + + assert instance_message.content.environment.hypervisor == HypervisorType.qemu + + assert mock_session_with_post_success.http_session.post.assert_called_once + assert isinstance(instance_message, InstanceMessage) + + +@pytest.mark.asyncio +async def test_create_confidential_instance(mock_session_with_post_success): + async with mock_session_with_post_success as session: + confidential_instance_message, message_status = await session.create_instance( + rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + rootfs_size=1, + channel="TEST", + metadata={"tags": ["test"]}, + payment=Payment( + chain=Chain.AVAX, + receiver="0x4145f182EF2F06b45E50468519C1B92C60FBd4A0", + type=PaymentType.superfluid, + ), + hypervisor=HypervisorType.qemu, + trusted_execution=TrustedExecutionEnvironment( + firmware=ItemHash( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ), + policy=0b1, + ), + requirements=HostRequirements( + node=NodeRequirements( + node_hash=ItemHash( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ), + ) + ), + ) + + assert mock_session_with_post_success.http_session.post.assert_called_once + assert isinstance(confidential_instance_message, InstanceMessage) + + @pytest.mark.asyncio async def test_forget(mock_session_with_post_success): async with mock_session_with_post_success as session: @@ -153,7 +212,7 @@ async def test_forget(mock_session_with_post_success): channel="TEST", ) - assert mock_session_with_post_success.http_session.post.called_once + assert mock_session_with_post_success.http_session.post.assert_called_once assert isinstance(forget_message, ForgetMessage) @@ -226,11 +285,80 @@ async def 
test_create_instance_insufficient_funds_error( await session.create_instance( rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", rootfs_size=1, - rootfs_name="rootfs", channel="TEST", metadata={"tags": ["test"]}, payment=Payment( chain=Chain.ETH, type=PaymentType.hold, + receiver=None, ), ) + + +@pytest.mark.asyncio +async def test_create_instance_with_credit_payment(mock_session_with_post_success): + """Test that an instance can be created with credit payment.""" + async with mock_session_with_post_success as session: + instance_message, message_status = await session.create_instance( + rootfs="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + rootfs_size=1, + channel="TEST", + metadata={"tags": ["test"]}, + payment=Payment( + chain=Chain.ETH, + receiver=None, + type=PaymentType.credit, + ), + ) + + assert instance_message.content.payment.type == PaymentType.credit + assert instance_message.content.payment.chain == Chain.ETH + assert instance_message.content.payment.receiver is None + + assert mock_session_with_post_success.http_session.post.assert_called_once + assert isinstance(instance_message, InstanceMessage) + + +@pytest.mark.asyncio +async def test_create_store_with_credit_payment(mock_session_with_post_success): + """Test that a store message can be created with credit payment.""" + mock_ipfs_push_file = AsyncMock() + mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" + + mock_session_with_post_success.ipfs_push_file = mock_ipfs_push_file + + async with mock_session_with_post_success as session: + store_message, message_status = await session.create_store( + file_content=b"HELLO", + channel="TEST", + storage_engine=StorageEnum.ipfs, + payment=Payment( + chain=Chain.ETH, + receiver=None, + type=PaymentType.credit, + ), + ) + + assert store_message.content.payment.type == PaymentType.credit + assert store_message.content.payment.chain == Chain.ETH + assert isinstance(store_message, 
StoreMessage) + + +@pytest.mark.asyncio +async def test_create_store_default_payment(mock_session_with_post_success): + """Test that a store message defaults to hold payment on ETH.""" + mock_ipfs_push_file = AsyncMock() + mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy" + + mock_session_with_post_success.ipfs_push_file = mock_ipfs_push_file + + async with mock_session_with_post_success as session: + store_message, message_status = await session.create_store( + file_content=b"HELLO", + channel="TEST", + storage_engine=StorageEnum.ipfs, + ) + + assert store_message.content.payment.type == PaymentType.hold + assert store_message.content.payment.chain == Chain.ETH + assert isinstance(store_message, StoreMessage) diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index 7cfb38f3..674becf7 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -23,6 +23,20 @@ async def test_fetch_aggregate(): assert response.keys() == {"nodes", "resource_nodes"} +@pytest.mark.asyncio +async def test_get_aggregate(): + mock_session = make_mock_get_session( + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + ) + async with mock_session: + response = await mock_session.get_aggregate( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", + key="corechannel", + ) + assert response is not None + assert response.keys() == {"nodes", "resource_nodes"} + + @pytest.mark.asyncio async def test_fetch_aggregates(): mock_session = make_mock_get_session( @@ -37,6 +51,21 @@ async def test_fetch_aggregates(): assert response["corechannel"].keys() == {"nodes", "resource_nodes"} +@pytest.mark.asyncio +async def test_get_aggregates(): + mock_session = make_mock_get_session( + {"data": {"corechannel": {"nodes": [], "resource_nodes": []}}} + ) + + async with mock_session: + response = await mock_session.get_aggregates( + address="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10" + ) + assert 
response is not None + assert response.keys() == {"corechannel"} + assert response["corechannel"].keys() == {"nodes", "resource_nodes"} + + @pytest.mark.asyncio async def test_get_posts(raw_posts_response): mock_session = make_mock_get_session(raw_posts_response(1)) diff --git a/tests/unit/test_balance.py b/tests/unit/test_balance.py new file mode 100644 index 00000000..793baa7d --- /dev/null +++ b/tests/unit/test_balance.py @@ -0,0 +1,39 @@ +from unittest.mock import patch + +import pytest + +from aleph.sdk.query.responses import BalanceResponse +from tests.unit.conftest import make_mock_get_session + + +@pytest.mark.asyncio +async def test_get_balances(): + """ + Test that the get_balances method returns the correct BalanceResponse + for a specific address when called on the AlephHttpClient. + """ + address = "0xd463495a6FEaC9921FD0C3a595B81E7B2C02B24d" + + balance_data = { + "address": address, + "balance": 351.25, + "details": {"ETH": 100.5, "SOL": 250.75}, + "locked_amount": 50.0, + "credit_balance": 1000, + } + + mock_client = make_mock_get_session(balance_data) + + expected_url = f"/api/v0/addresses/{address}/balance" + # Adding type assertion to handle None case + assert mock_client._http_session is not None + with patch.object( + mock_client._http_session, "get", wraps=mock_client._http_session.get + ) as spy: + async with mock_client: + response = await mock_client.get_balances(address) + + # Verify the response + assert isinstance(response, BalanceResponse) + # Verify the balances command calls the correct URL + spy.assert_called_once_with(expected_url, params=None) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 07b67602..0fbd717e 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -8,7 +8,12 @@ from nacl.signing import VerifyKey from aleph.sdk.chains.common import get_verification_buffer -from aleph.sdk.chains.sol import SOLAccount, get_fallback_account, verify_signature +from 
aleph.sdk.chains.solana import ( + SOLAccount, + get_fallback_account, + parse_private_key, + verify_signature, +) from aleph.sdk.exceptions import BadSignatureError @@ -136,3 +141,56 @@ async def test_sign_raw(solana_account): assert isinstance(signature, bytes) verify_signature(signature, solana_account.get_address(), buffer) + + +def test_parse_solana_private_key_bytes(): + # Valid 32-byte private key + private_key_bytes = bytes(range(32)) + parsed_key = parse_private_key(private_key_bytes) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + assert parsed_key == private_key_bytes + + # Invalid private key (too short) + with pytest.raises( + ValueError, match="The private key in bytes must be exactly 32 bytes long." + ): + parse_private_key(bytes(range(31))) + + +def test_parse_solana_private_key_base58(): + # Valid base58 private key (32 bytes) + base58_key = base58.b58encode(bytes(range(32))).decode("utf-8") + parsed_key = parse_private_key(base58_key) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + + # Invalid base58 key (not decodable) + with pytest.raises(ValueError, match="Invalid base58 encoded private key"): + parse_private_key("invalid_base58_key") + + # Invalid base58 key (wrong length) + with pytest.raises( + ValueError, + match="The base58 decoded private key must be either 32 or 64 bytes long.", + ): + parse_private_key(base58.b58encode(bytes(range(31))).decode("utf-8")) + + +def test_parse_solana_private_key_list(): + # Valid list of uint8 integers (64 elements, but we only take the first 32 for private key) + uint8_list = list(range(64)) + parsed_key = parse_private_key(uint8_list) + assert isinstance(parsed_key, bytes) + assert len(parsed_key) == 32 + assert parsed_key == bytes(range(32)) + + # Invalid list (contains non-integers) + with pytest.raises(ValueError, match="Invalid uint8 array"): + parse_private_key([1, 2, "not an int", 4]) # type: ignore # Ignore type check for string + + # Invalid list 
(less than 32 elements) + with pytest.raises( + ValueError, match="The uint8 array must contain at least 32 elements." + ): + parse_private_key(list(range(31))) diff --git a/tests/unit/test_chain_svm.py b/tests/unit/test_chain_svm.py new file mode 100644 index 00000000..ced673c3 --- /dev/null +++ b/tests/unit/test_chain_svm.py @@ -0,0 +1,199 @@ +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from tempfile import NamedTemporaryFile + +import base58 +import pytest +from aleph_message.models import Chain +from nacl.signing import VerifyKey + +from aleph.sdk.chains.common import get_verification_buffer +from aleph.sdk.chains.solana import get_fallback_account as get_solana_account +from aleph.sdk.chains.solana import verify_signature +from aleph.sdk.chains.svm import SVMAccount +from aleph.sdk.exceptions import BadSignatureError + + +@dataclass +class Message: + chain: str + sender: str + type: str + item_hash: str + + +@pytest.fixture +def svm_account() -> SVMAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + solana_account = get_solana_account(path=Path(private_key_file.name)) + return SVMAccount(private_key=solana_account.private_key) + + +@pytest.fixture +def svm_eclipse_account() -> SVMAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + solana_account = get_solana_account(path=Path(private_key_file.name)) + return SVMAccount(private_key=solana_account.private_key, chain=Chain.ECLIPSE) + + +def test_svm_account_init(): + with NamedTemporaryFile() as private_key_file: + solana_account = get_solana_account(path=Path(private_key_file.name)) + account = SVMAccount(private_key=solana_account.private_key) + + # Default chain should be SOL + assert account.CHAIN == Chain.SOL + assert account.CURVE == "curve25519" + assert account._signing_key.verify_key + assert isinstance(account.private_key, bytes) + assert len(account.private_key) == 32 + + 
# Test with custom chain + account_eclipse = SVMAccount( + private_key=solana_account.private_key, chain=Chain.ECLIPSE + ) + assert account_eclipse.CHAIN == Chain.ECLIPSE + + +@pytest.mark.asyncio +async def test_svm_sign_message(svm_account): + message = asdict(Message("ES", svm_account.get_address(), "SomeType", "ItemHash")) + initial_message = message.copy() + await svm_account.sign_message(message) + assert message["signature"] + + address = message["sender"] + assert address + assert isinstance(address, str) + signature = json.loads(message["signature"]) + + pubkey = base58.b58decode(signature["publicKey"]) + assert isinstance(pubkey, bytes) + assert len(pubkey) == 32 + + verify_key = VerifyKey(pubkey) + verification_buffer = get_verification_buffer(message) + assert get_verification_buffer(initial_message) == verification_buffer + verif = verify_key.verify( + verification_buffer, signature=base58.b58decode(signature["signature"]) + ) + + assert verif == verification_buffer + assert message["sender"] == signature["publicKey"] + + pubkey = svm_account.get_public_key() + assert isinstance(pubkey, str) + assert len(pubkey) == 64 + + +@pytest.mark.asyncio +async def test_svm_custom_chain_sign_message(svm_eclipse_account): + message = asdict( + Message( + Chain.ECLIPSE, svm_eclipse_account.get_address(), "SomeType", "ItemHash" + ) + ) + await svm_eclipse_account.sign_message(message) + assert message["signature"] + + # Verify message has correct chain + assert message["chain"] == Chain.ECLIPSE + + # Rest of verification is the same + signature = json.loads(message["signature"]) + pubkey = base58.b58decode(signature["publicKey"]) + verify_key = VerifyKey(pubkey) + verification_buffer = get_verification_buffer(message) + verif = verify_key.verify( + verification_buffer, signature=base58.b58decode(signature["signature"]) + ) + assert verif == verification_buffer + + +@pytest.mark.asyncio +async def test_svm_decrypt(svm_account): + assert svm_account.CURVE == 
"curve25519" + content = b"SomeContent" + + encrypted = await svm_account.encrypt(content) + assert isinstance(encrypted, bytes) + decrypted = await svm_account.decrypt(encrypted) + assert isinstance(decrypted, bytes) + assert content == decrypted + + +@pytest.mark.asyncio +async def test_svm_verify_signature(svm_account): + message = asdict( + Message( + "SVM", + svm_account.get_address(), + "POST", + "SomeHash", + ) + ) + await svm_account.sign_message(message) + assert message["signature"] + raw_signature = json.loads(message["signature"])["signature"] + assert isinstance(raw_signature, str) + + verify_signature(raw_signature, message["sender"], get_verification_buffer(message)) + + # as bytes + verify_signature( + base58.b58decode(raw_signature), + base58.b58decode(message["sender"]), + get_verification_buffer(message).decode("utf-8"), + ) + + +@pytest.mark.asyncio +async def test_verify_signature_with_forged_signature(svm_account): + message = asdict( + Message( + "SVM", + svm_account.get_address(), + "POST", + "SomeHash", + ) + ) + await svm_account.sign_message(message) + assert message["signature"] + # create forged 64 bit signature from random bytes + forged = base58.b58encode(bytes(64)).decode("utf-8") + + with pytest.raises(BadSignatureError): + verify_signature(forged, message["sender"], get_verification_buffer(message)) + + +@pytest.mark.asyncio +async def test_svm_sign_raw(svm_account): + buffer = b"SomeBuffer" + signature = await svm_account.sign_raw(buffer) + assert signature + assert isinstance(signature, bytes) + + verify_signature(signature, svm_account.get_address(), buffer) + + +def test_svm_with_various_chain_values(): + # Test with different chain formats + with NamedTemporaryFile() as private_key_file: + solana_account = get_solana_account(path=Path(private_key_file.name)) + + # Test with string + account1 = SVMAccount(private_key=solana_account.private_key, chain="ES") + assert account1.CHAIN == Chain.ECLIPSE + + # Test with Chain enum if 
it exists + account2 = SVMAccount( + private_key=solana_account.private_key, chain=Chain.ECLIPSE + ) + assert account2.CHAIN == Chain.ECLIPSE + + # Test default + account3 = SVMAccount(private_key=solana_account.private_key) + assert account3.CHAIN == Chain.SOL diff --git a/tests/unit/test_chain_tezos.py b/tests/unit/test_chain_tezos.py index 0beaffc9..96e52ca3 100644 --- a/tests/unit/test_chain_tezos.py +++ b/tests/unit/test_chain_tezos.py @@ -31,7 +31,7 @@ async def test_tezos_account(tezos_account: TezosAccount): message = Message("TEZOS", tezos_account.get_address(), "SomeType", "ItemHash") signed = await tezos_account.sign_message(asdict(message)) assert signed["signature"] - assert len(signed["signature"]) == 188 + assert len(signed["signature"]) == 187 address = tezos_account.get_address() assert address is not None @@ -40,7 +40,7 @@ async def test_tezos_account(tezos_account: TezosAccount): pubkey = tezos_account.get_public_key() assert isinstance(pubkey, str) - assert len(pubkey) == 55 + assert len(pubkey) == 54 @pytest.mark.asyncio diff --git a/tests/unit/test_credits.py b/tests/unit/test_credits.py new file mode 100644 index 00000000..fb2e6d95 --- /dev/null +++ b/tests/unit/test_credits.py @@ -0,0 +1,75 @@ +from unittest.mock import patch + +import pytest + +from aleph.sdk.query.responses import CreditsHistoryResponse +from tests.unit.conftest import make_mock_get_session + + +@pytest.mark.asyncio +async def test_get_credits_history(): + """ + Test credits history commands + """ + address = "0xd463495a6FEaC9921FD0C3a595B81E7B2C02B24d" + + # Mock data for credit history + credit_history_data = { + "address": address, + "credit_history": [ + { + "amount": -22, + "ratio": None, + "tx_hash": None, + "token": None, + "chain": None, + "provider": "ALEPH", + "origin": None, + "origin_ref": "212f4825dd30e01f3801cdff1bdf8cd4d1c14ce2d31d695aee429d2ad0dfcba1", + "payment_method": "credit_expense", + "credit_ref": 
"cd77a7983af168941fd011427c6198b146ccd6f85077e0b593a4e7239d45fb11", + "credit_index": 0, + "expiration_date": None, + "message_timestamp": "2025-09-30T06:57:26.106000Z", + }, + { + "amount": -22, + "ratio": None, + "tx_hash": None, + "token": None, + "chain": None, + "provider": "ALEPH", + "origin": None, + "origin_ref": "36ceb85fb570fc87a6b906dc89df39129a971de96cbc56250553cfb8d49487e3", + "payment_method": "credit_expense", + "credit_ref": "5881c8f813ea186b25a9a20d9bea46e2082c4d61c2b9e7d53bf8a164dc892b73", + "credit_index": 0, + "expiration_date": None, + "message_timestamp": "2025-09-30T02:57:07.673000Z", + }, + ], + "pagination_page": 1, + "pagination_total": 1, + "pagination_per_page": 200, + "pagination_item": "credit_history", + } + + mock_client = make_mock_get_session(credit_history_data) + + # Test the method with a specific address + expected_url = f"/api/v0/addresses/{address}/credit_history" + # Adding type assertion to handle None case + assert mock_client._http_session is not None + with patch.object( + mock_client._http_session, "get", wraps=mock_client._http_session.get + ) as spy: + async with mock_client: + response = await mock_client.get_credit_history(address) + + # Verify the response + assert isinstance(response, CreditsHistoryResponse) + # Verify the credits history commands call the correct url + spy.assert_called_once_with( + expected_url, params={"page": "1", "pagination": "200"} + ) + assert len(response.credit_history) == 2 diff --git a/tests/unit/test_domains.py b/tests/unit/test_domains.py index 380e4bb5..eadfcec1 100644 --- a/tests/unit/test_domains.py +++ b/tests/unit/test_domains.py @@ -47,7 +47,7 @@ async def test_configured_domain(): url = "https://custom-domain-unit-test.aleph.sh" hostname = hostname_from_url(url) status = await alephdns.check_domain(hostname, TargetType.IPFS, "0xfakeaddress") - assert type(status) is dict + assert isinstance(status, dict) @pytest.mark.asyncio @@ -57,4 +57,4 @@ async def 
test_not_configured_domain(): hostname = hostname_from_url(url) with pytest.raises(DomainConfigurationError): status = await alephdns.check_domain(hostname, TargetType.IPFS, "0xfakeaddress") - assert type(status) is None + assert status is None diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index 377e6d41..a889949d 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -1,7 +1,19 @@ +import tempfile +from pathlib import Path + import pytest from aleph.sdk import AlephHttpClient -from aleph.sdk.conf import settings as sdk_settings + +from .conftest import make_mock_get_session + + +def make_mock_download_client(item_hash: str) -> AlephHttpClient: + if item_hash == "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH": + return make_mock_get_session(b"test\n") + if item_hash == "Qmdy5LaAL4eghxE7JD6Ah5o4PJGarjAV9st8az2k52i1vq": + return make_mock_get_session(bytes(5817703)) + raise NotImplementedError @pytest.mark.parametrize( @@ -13,10 +25,30 @@ ) @pytest.mark.asyncio async def test_download(file_hash: str, expected_size: int): - async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: - file_content = await client.download_file(file_hash) # File is 5B - file_size = len(file_content) - assert file_size == expected_size + mock_download_client = make_mock_download_client(file_hash) + async with mock_download_client: + file_content = await mock_download_client.download_file(file_hash) + file_size = len(file_content) + assert file_size == expected_size + + +@pytest.mark.asyncio +async def test_download_to_file(): + file_hash = "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH" + mock_download_client = make_mock_download_client(file_hash) + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + download_path = temp_dir_path / "test.txt" + + async with mock_download_client: + returned_path = await mock_download_client.download_file_to_path( + file_hash, str(download_path) + ) + + assert 
returned_path == download_path + assert download_path.is_file() + with open(download_path, "r") as file: + assert file.read().strip() == "test" @pytest.mark.parametrize( @@ -28,7 +60,8 @@ async def test_download(file_hash: str, expected_size: int): ) @pytest.mark.asyncio async def test_download_ipfs(file_hash: str, expected_size: int): - async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client: - file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE - file_size = len(file_content) - assert file_size == expected_size + mock_download_client = make_mock_download_client(file_hash) + async with mock_download_client: + file_content = await mock_download_client.download_file_ipfs(file_hash) + file_size = len(file_content) + assert file_size == expected_size diff --git a/tests/unit/test_gas_estimation.py b/tests/unit/test_gas_estimation.py new file mode 100644 index 00000000..7db391ad --- /dev/null +++ b/tests/unit/test_gas_estimation.py @@ -0,0 +1,182 @@ +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from aleph_message.models import Chain +from web3.exceptions import ContractCustomError +from web3.types import TxParams + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.connectors.superfluid import Superfluid +from aleph.sdk.exceptions import InsufficientFundsError +from aleph.sdk.types import TokenType + + +@pytest.fixture +def mock_eth_account(): + private_key = b"\x01" * 32 + account = ETHAccount( + private_key, + chain=Chain.ETH, + ) + account._provider = MagicMock() + account._provider.eth = MagicMock() + account._provider.eth.gas_price = 20_000_000_000 # 20 Gwei + account._provider.eth.estimate_gas = MagicMock( + return_value=100_000 + ) # 100k gas units + + # Mock get_eth_balance to return a specific balance + with patch.object(account, "get_eth_balance", return_value=10**18): # 1 ETH + yield account + + +@pytest.fixture +def mock_superfluid(mock_eth_account): + 
superfluid = Superfluid(mock_eth_account) + superfluid.cfaV1Instance = MagicMock() + superfluid.cfaV1Instance.create_flow = MagicMock() + superfluid.super_token = "0x0000000000000000000000000000000000000000" + superfluid.normalized_address = "0x0000000000000000000000000000000000000000" + + # Mock the operation + operation = MagicMock() + operation._get_populated_transaction_request = MagicMock( + return_value={"value": 0, "gas": 100000, "gasPrice": 20_000_000_000} + ) + superfluid.cfaV1Instance.create_flow.return_value = operation + + return superfluid + + +class TestGasEstimation: + def test_can_transact_with_sufficient_funds(self, mock_eth_account): + tx = TxParams({"to": "0xreceiver", "value": 0}) + + # Should pass with 1 ETH balance against ~0.002 ETH gas cost + assert mock_eth_account.can_transact(tx=tx, block=True) is True + + def test_can_transact_with_insufficient_funds(self, mock_eth_account): + tx = TxParams({"to": "0xreceiver", "value": 0}) + + # Set balance to almost zero + with patch.object(mock_eth_account, "get_eth_balance", return_value=1000): + # Should raise InsufficientFundsError + with pytest.raises(InsufficientFundsError) as exc_info: + mock_eth_account.can_transact(tx=tx, block=True) + + assert exc_info.value.token_type == TokenType.GAS + + def test_can_transact_with_legacy_gas_price(self, mock_eth_account): + tx = TxParams( + {"to": "0xreceiver", "value": 0, "gasPrice": 30_000_000_000} # 30 Gwei + ) + + # Should use the tx's gasPrice instead of default + mock_eth_account.can_transact(tx=tx, block=True) + + # It should have used the tx's gasPrice for calculation + mock_eth_account._provider.eth.estimate_gas.assert_called_once() + + def test_can_transact_with_eip1559_gas(self, mock_eth_account): + tx = TxParams( + {"to": "0xreceiver", "value": 0, "maxFeePerGas": 40_000_000_000} # 40 Gwei + ) + + # Should use the tx's maxFeePerGas + mock_eth_account.can_transact(tx=tx, block=True) + + # It should have used the tx's maxFeePerGas for calculation + 
mock_eth_account._provider.eth.estimate_gas.assert_called_once() + + def test_can_transact_with_contract_error(self, mock_eth_account): + tx = TxParams({"to": "0xreceiver", "value": 0}) + + # Make estimate_gas throw a ContractCustomError + mock_eth_account._provider.eth.estimate_gas.side_effect = ContractCustomError( + "error" + ) + + # Should fallback to MIN_ETH_BALANCE_WEI + mock_eth_account.can_transact(tx=tx, block=True) + + # It should have called estimate_gas + mock_eth_account._provider.eth.estimate_gas.assert_called_once() + + +class TestSuperfluidFlowEstimation: + @pytest.mark.asyncio + async def test_simulate_create_tx_flow_success( + self, mock_superfluid, mock_eth_account + ): + # Patch both the _get_populated_transaction_request and can_transact methods + mock_tx = {"value": 0, "gas": 100000, "gasPrice": 20_000_000_000} + with patch.object( + mock_superfluid, "_get_populated_transaction_request", return_value=mock_tx + ): + with patch.object(mock_eth_account, "can_transact", return_value=True): + result = mock_superfluid._simulate_create_tx_flow(Decimal("0.00000005")) + assert result is True + + # Verify the flow was correctly simulated but not executed + mock_superfluid.cfaV1Instance.create_flow.assert_called_once() + assert "0x0000000000000000000000000000000000000001" in str( + mock_superfluid.cfaV1Instance.create_flow.call_args + ) + + @pytest.mark.asyncio + async def test_simulate_create_tx_flow_contract_error( + self, mock_superfluid, mock_eth_account + ): + # Setup a contract error code for insufficient deposit + error = ContractCustomError("Insufficient deposit") + error.data = "0xea76c9b3" # This is the specific error code checked in the code + + # Mock _get_populated_transaction_request and can_transact + mock_tx = {"value": 0, "gas": 100000, "gasPrice": 20_000_000_000} + with patch.object( + mock_superfluid, "_get_populated_transaction_request", return_value=mock_tx + ): + # Mock can_transact to throw the error + with 
patch.object(mock_eth_account, "can_transact", side_effect=error): + # Also mock get_super_token_balance for the error case + with patch.object( + mock_eth_account, "get_super_token_balance", return_value=0 + ): + # Should raise InsufficientFundsError for ALEPH token + with pytest.raises(InsufficientFundsError) as exc_info: + mock_superfluid._simulate_create_tx_flow(Decimal("0.00000005")) + + assert exc_info.value.token_type == TokenType.ALEPH + + @pytest.mark.asyncio + async def test_simulate_create_tx_flow_other_error( + self, mock_superfluid, mock_eth_account + ): + # Setup a different contract error code + error = ContractCustomError("Other error") + error.data = "0xsomeothercode" + + # Mock _get_populated_transaction_request and can_transact + mock_tx = {"value": 0, "gas": 100000, "gasPrice": 20_000_000_000} + with patch.object( + mock_superfluid, "_get_populated_transaction_request", return_value=mock_tx + ): + # Mock can_transact to throw the error + with patch.object(mock_eth_account, "can_transact", side_effect=error): + # Should return False for other errors + result = mock_superfluid._simulate_create_tx_flow(Decimal("0.00000005")) + assert result is False + + @pytest.mark.asyncio + async def test_can_start_flow_uses_simulation(self, mock_superfluid): + # Mock _simulate_create_tx_flow to verify it's called + with patch.object( + mock_superfluid, "_simulate_create_tx_flow", return_value=True + ) as mock_simulate: + result = mock_superfluid.can_start_flow(Decimal("0.00000005")) + + assert result is True + mock_simulate.assert_called_once_with( + flow=Decimal("0.00000005"), block=True + ) diff --git a/tests/unit/test_price.py b/tests/unit/test_price.py new file mode 100644 index 00000000..f2759193 --- /dev/null +++ b/tests/unit/test_price.py @@ -0,0 +1,85 @@ +from decimal import Decimal + +import pytest + +from aleph.sdk.exceptions import InvalidHashError +from aleph.sdk.query.responses import PriceResponse +from tests.unit.conftest import 
make_mock_get_session, make_mock_get_session_400 + + +@pytest.mark.asyncio +async def test_get_program_price_valid(): + """ + Test that the get_program_price method returns the correct PriceResponse + when given a valid item hash. + """ + expected = PriceResponse( + required_tokens=3.0555555555555556e-06, + payment_type="superfluid", + ) + mock_session = make_mock_get_session(expected.model_dump()) + async with mock_session: + response = await mock_session.get_program_price("cacacacacacaca") + assert response == expected + + +@pytest.mark.asyncio +async def test_get_program_price_cost_and_required_token(): + """ + Test that the get_program_price method returns the correct PriceResponse + when + 1 ) cost & required_token is here (priority to cost) who is a string that convert to decimal + 2 ) When only required_token is here who is a float that now would be to be convert to decimal + """ + # Case 1 + expected = { + "required_tokens": 0.001527777777777778, + "cost": "0.001527777777777777", + "payment_type": "credit", + } + + # Case 2 + expected_old = { + "required_tokens": 0.001527777777777778, + "payment_type": "credit", + } + + # Expected model using the cost field as the source of truth + expected_model = PriceResponse( + required_tokens=Decimal("0.001527777777777778"), + cost=expected["cost"], + payment_type=expected["payment_type"], + ) + + # Expected model for the old format + expected_model_old = PriceResponse( + required_tokens=Decimal(str(expected_old["required_tokens"])), + payment_type=expected_old["payment_type"], + ) + + mock_session = make_mock_get_session(expected) + mock_session_old = make_mock_get_session(expected_old) + + async with mock_session: + response = await mock_session.get_program_price("cacacacacacaca") + assert str(response.required_tokens) == str(expected_model.required_tokens) + assert response.cost == expected_model.cost + assert response.payment_type == expected_model.payment_type + + async with mock_session_old: + response = await 
mock_session_old.get_program_price("cacacacacacaca") + assert str(response.required_tokens) == str(expected_model_old.required_tokens) + assert response.cost == expected_model_old.cost + assert response.payment_type == expected_model_old.payment_type + + +@pytest.mark.asyncio +async def test_get_program_price_invalid(): + """ + Test that the get_program_price method raises an InvalidHashError + when given an invalid item hash. + """ + mock_session = make_mock_get_session_400({"error": "Invalid hash"}) + async with mock_session: + with pytest.raises(InvalidHashError): + await mock_session.get_program_price("invalid_item_hash") diff --git a/tests/unit/test_remote_account.py b/tests/unit/test_remote_account.py index cb4a2af5..3abe979e 100644 --- a/tests/unit/test_remote_account.py +++ b/tests/unit/test_remote_account.py @@ -22,7 +22,7 @@ async def test_remote_storage(): curve="secp256k1", address=local_account.get_address(), public_key=local_account.get_public_key(), - ).dict() + ).model_dump() ) remote_account = await RemoteAccount.from_crypto_host( diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py new file mode 100644 index 00000000..762fceea --- /dev/null +++ b/tests/unit/test_services.py @@ -0,0 +1,445 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient +from aleph.sdk.client.services.authenticated_port_forwarder import ( + AuthenticatedPortForwarder, + PortForwarder, +) +from aleph.sdk.client.services.crn import Crn +from aleph.sdk.client.services.dns import DNS +from aleph.sdk.client.services.instance import Instance +from aleph.sdk.client.services.scheduler import Scheduler +from aleph.sdk.types import ( + IPV4, + AllocationItem, + Dns, + PortFlags, + Ports, + SchedulerNodes, + SchedulerPlan, +) + + +@pytest.mark.asyncio +async def test_aleph_http_client_services_loading(): + """Test that services are properly loaded in 
AlephHttpClient's __aenter__""" + with patch("aiohttp.ClientSession") as mock_session: + mock_session_instance = AsyncMock() + mock_session.return_value = mock_session_instance + + client = AlephHttpClient(api_server="http://localhost") + + async def mocked_aenter(): + client._http_session = mock_session_instance + client.dns = DNS(client) + client.port_forwarder = PortForwarder(client) + client.crn = Crn(client) + client.scheduler = Scheduler(client) + client.instance = Instance(client) + return client + + with patch.object(client, "__aenter__", mocked_aenter), patch.object( + client, "__aexit__", AsyncMock() + ): + async with client: + assert isinstance(client.dns, DNS) + assert isinstance(client.port_forwarder, PortForwarder) + assert isinstance(client.crn, Crn) + assert isinstance(client.scheduler, Scheduler) + assert isinstance(client.instance, Instance) + + assert client.dns._client == client + assert client.port_forwarder._client == client + assert client.crn._client == client + assert client.scheduler._client == client + assert client.instance._client == client + + +@pytest.mark.asyncio +async def test_authenticated_http_client_services_loading(ethereum_account): + """Test that authenticated services are properly loaded in AuthenticatedAlephHttpClient's __aenter__""" + with patch("aiohttp.ClientSession") as mock_session: + mock_session_instance = AsyncMock() + mock_session.return_value = mock_session_instance + + client = AuthenticatedAlephHttpClient( + account=ethereum_account, api_server="http://localhost" + ) + + async def mocked_aenter(): + client._http_session = mock_session_instance + client.dns = DNS(client) + client.port_forwarder = AuthenticatedPortForwarder(client) + client.crn = Crn(client) + client.scheduler = Scheduler(client) + client.instance = Instance(client) + return client + + with patch.object(client, "__aenter__", mocked_aenter), patch.object( + client, "__aexit__", AsyncMock() + ): + async with client: + assert isinstance(client.dns, 
DNS) + assert isinstance(client.port_forwarder, AuthenticatedPortForwarder) + assert isinstance(client.crn, Crn) + assert isinstance(client.scheduler, Scheduler) + assert isinstance(client.instance, Instance) + + assert client.dns._client == client + assert client.port_forwarder._client == client + assert client.crn._client == client + assert client.scheduler._client == client + assert client.instance._client == client + + +def mock_aiohttp_session(response_data, raise_error=False, error_status=404): + """ + Creates a mock for aiohttp.ClientSession that properly handles async context managers. + + Args: + response_data: The data to return from the response's json() method + raise_error: Whether to raise an aiohttp.ClientResponseError + error_status: The HTTP status code to use if raising an error + + Returns: + A tuple of (patch_target, mock_session_context, mock_session, mock_response) + """ + # Mock the response object + mock_response = MagicMock() + + if raise_error: + # Set up raise_for_status to raise an exception + error = aiohttp.ClientResponseError( + request_info=MagicMock(), + history=tuple(), + status=error_status, + message="Not Found" if error_status == 404 else "Error", + ) + mock_response.raise_for_status = MagicMock(side_effect=error) + else: + # Normal case - just return the data + mock_response.raise_for_status = MagicMock() + mock_response.json = AsyncMock(return_value=response_data) + + # Mock the context manager for session.get + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_response) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + + # Mock the session's get method to return our context manager + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_context_manager) + mock_session.post = MagicMock(return_value=mock_context_manager) + + # Mock the ClientSession context manager + mock_session_context = MagicMock() + mock_session_context.__aenter__ = 
AsyncMock(return_value=mock_session) + mock_session_context.__aexit__ = AsyncMock(return_value=None) + + return "aiohttp.ClientSession", mock_session_context, mock_session, mock_response + + +@pytest.mark.asyncio +async def test_authenticated_port_forwarder_create_port_forward(ethereum_account): + """Test the create_port method in AuthenticatedPortForwarder""" + mock_client = MagicMock() + mock_client.http_session = AsyncMock() + mock_client.account = ethereum_account + + auth_port_forwarder = AuthenticatedPortForwarder(mock_client) + + ports = Ports(ports={80: PortFlags(tcp=True, udp=False)}) + + mock_message = MagicMock() + mock_status = MagicMock() + + # Setup the mock for create_aggregate + mock_client.create_aggregate = AsyncMock(return_value=(mock_message, mock_status)) + + # Mock the _verify_status_processed_and_ownership method + with patch.object( + auth_port_forwarder, + "_verify_status_processed_and_ownership", + AsyncMock(return_value=(mock_message, mock_status)), + ): + # Call the actual method + result_message, result_status = await auth_port_forwarder.create_ports( + item_hash="test_hash", ports=ports + ) + + # Verify create_aggregate was called + mock_client.create_aggregate.assert_called_once() + + # Check the parameters passed to create_aggregate + call_args = mock_client.create_aggregate.call_args + assert call_args[1]["key"] == "port-forwarding" + assert "test_hash" in call_args[1]["content"] + + # Verify the method returns what create_aggregate returns + assert result_message == mock_message + assert result_status == mock_status + + +@pytest.mark.asyncio +async def test_authenticated_port_forwarder_update_port(ethereum_account): + """Test the update_port method in AuthenticatedPortForwarder""" + mock_client = MagicMock() + mock_client.http_session = AsyncMock() + mock_client.account = ethereum_account + + auth_port_forwarder = AuthenticatedPortForwarder(mock_client) + + ports = Ports(ports={80: PortFlags(tcp=True, udp=False)}) + + mock_message 
= MagicMock() + mock_status = MagicMock() + + # Setup the mock for create_aggregate + mock_client.create_aggregate = AsyncMock(return_value=(mock_message, mock_status)) + + # Mock the _verify_status_processed_and_ownership method + with patch.object( + auth_port_forwarder, + "_verify_status_processed_and_ownership", + AsyncMock(return_value=(mock_message, mock_status)), + ): + # Call the actual method + result_message, result_status = await auth_port_forwarder.update_ports( + item_hash="test_hash", ports=ports + ) + + # Verify create_aggregate was called + mock_client.create_aggregate.assert_called_once() + + # Check the parameters passed to create_aggregate + call_args = mock_client.create_aggregate.call_args + assert call_args[1]["key"] == "port-forwarding" + assert "test_hash" in call_args[1]["content"] + + # Verify the method returns what create_aggregate returns + assert result_message == mock_message + assert result_status == mock_status + + +@pytest.mark.asyncio +async def test_dns_service_get_public_dns(): + """Test the DNSService get_public_dns method""" + mock_client = MagicMock() + dns_service = DNS(mock_client) + + # Mock the DnsListAdapter with a valid 64-character hash for ItemHash + mock_dns_list = [ + Dns( + name="test.aleph.sh", + item_hash="b236db23bf5ad005ad7f5d82eed08a68a925020f0755b2a59c03f784499198eb", + ipv6="2001:db8::1", + ipv4=IPV4(public="192.0.2.1", local="10.0.0.1"), + ) + ] + + # Patch DnsListAdapter.validate_json to return our mock DNS list + with patch( + "aleph.sdk.types.DnsListAdapter.validate_json", return_value=mock_dns_list + ): + # Set up mock for aiohttp.ClientSession to return a string (which is what validate_json expects) + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + '["dummy json string"]' + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await dns_service.get_public_dns() + + assert len(result) == 1 + assert result[0].name == 
"test.aleph.sh" + assert ( + result[0].item_hash + == "b236db23bf5ad005ad7f5d82eed08a68a925020f0755b2a59c03f784499198eb" + ) + assert result[0].ipv6 == "2001:db8::1" + assert result[0].ipv4 is not None and result[0].ipv4.public == "192.0.2.1" + + +@pytest.mark.asyncio +async def test_crn_service_get_last_crn_version(): + """Test the CrnService get_last_crn_version method""" + mock_client = MagicMock() + crn_service = Crn(mock_client) + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + {"tag_name": "v1.2.3"} + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await crn_service.get_last_crn_version() + assert result == "v1.2.3" + + +@pytest.mark.asyncio +async def test_scheduler_service_get_plan(): + """Test the SchedulerService get_plan method""" + mock_client = MagicMock() + scheduler_service = Scheduler(mock_client) + + mock_plan_data = { + "period": {"start_timestamp": "2023-01-01T00:00:00Z", "duration_seconds": 3600}, + "plan": { + "node1": { + "persistent_vms": [ + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210", + ], + "instances": [], + "on_demand_vms": [], + "jobs": [], + } + }, + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session(mock_plan_data) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_plan() + assert isinstance(result, SchedulerPlan) + assert "node1" in result.plan + assert ( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + in result.plan["node1"].persistent_vms + ) + + +@pytest.mark.asyncio +async def test_scheduler_service_get_scheduler_node(): + """Test the SchedulerService get_scheduler_node method""" + mock_client = MagicMock() + scheduler_service = 
Scheduler(mock_client) + + mock_nodes_data = { + "nodes": [ + { + "node_id": "node1", + "url": "https://node1.aleph.im", + "ipv6": "2001:db8::1", + "supports_ipv6": True, + }, + { + "node_id": "node2", + "url": "https://node2.aleph.im", + "ipv6": None, + "supports_ipv6": False, + }, + ] + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session(mock_nodes_data) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_nodes() + assert isinstance(result, SchedulerNodes) + assert len(result.nodes) == 2 + assert result.nodes[0].node_id == "node1" + assert result.nodes[1].url == "https://node2.aleph.im" + + +@pytest.mark.asyncio +async def test_scheduler_service_get_allocation(): + """Test the SchedulerService get_allocation method""" + mock_client = MagicMock() + scheduler_service = Scheduler(mock_client) + + mock_allocation_data = { + "vm_hash": "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "vm_type": "instance", + "vm_ipv6": "2001:db8::1", + "period": {"start_timestamp": "2023-01-01T00:00:00Z", "duration_seconds": 3600}, + "node": { + "node_id": "node1", + "url": "https://node1.aleph.im", + "ipv6": "2001:db8::1", + "supports_ipv6": True, + }, + } + + # Set up mock for aiohttp.ClientSession + patch_target, mock_session_context, _, _ = mock_aiohttp_session( + mock_allocation_data + ) + + # Patch the ClientSession constructor + with patch(patch_target, return_value=mock_session_context): + result = await scheduler_service.get_allocation( + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + ) + assert isinstance(result, AllocationItem) + assert ( + result.vm_hash + == "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + ) + assert result.node.node_id == "node1" + + +@pytest.mark.asyncio +async def test_utils_service_get_name_of_executable(): + """Test the UtilsService 
get_name_of_executable method""" + mock_client = MagicMock() + utils_service = Instance(mock_client) + + # Mock a message with metadata.name + mock_message = MagicMock() + mock_message.content.metadata = {"name": "test-executable"} + + # Set up the client mock to return the message + mock_client.get_message = AsyncMock(return_value=mock_message) + + # Test successful case + result = await utils_service.get_name_of_executable("hash1") + assert result == "test-executable" + + # Test with dict response + mock_client.get_message = AsyncMock( + return_value={"content": {"metadata": {"name": "dict-executable"}}} + ) + + result = await utils_service.get_name_of_executable("hash2") + assert result == "dict-executable" + + # Test with exception + mock_client.get_message = AsyncMock(side_effect=Exception("Test exception")) + + result = await utils_service.get_name_of_executable("hash3") + assert result is None + + +@pytest.mark.asyncio +async def test_utils_service_get_instances(): + """Test the UtilsService get_instances method""" + mock_client = MagicMock() + utils_service = Instance(mock_client) + + # Mock messages response + mock_messages = [MagicMock(), MagicMock()] + mock_response = MagicMock() + mock_response.messages = mock_messages + + # Set up the client mock + mock_client.get_messages = AsyncMock(return_value=mock_response) + + result = await utils_service.get_instances("0xaddress") + + # Check that get_messages was called with correct parameters + mock_client.get_messages.assert_called_once() + call_args = mock_client.get_messages.call_args[1] + assert call_args["page_size"] == 100 + assert call_args["message_filter"].addresses == ["0xaddress"] + + # Check result + assert result == mock_messages diff --git a/tests/unit/test_superfluid.py b/tests/unit/test_superfluid.py new file mode 100644 index 00000000..74bcc38e --- /dev/null +++ b/tests/unit/test_superfluid.py @@ -0,0 +1,113 @@ +import random +from decimal import Decimal +from unittest.mock import AsyncMock, 
patch + +import pytest +from aleph_message.models import Chain +from eth_utils import to_checksum_address + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.evm_utils import FlowUpdate + + +def generate_fake_eth_address(): + return to_checksum_address( + "0x" + "".join([random.choice("0123456789abcdef") for _ in range(40)]) + ) + + +@pytest.fixture +def mock_superfluid(): + with patch("aleph.sdk.connectors.superfluid.Superfluid") as MockSuperfluid: + mock_superfluid = MockSuperfluid.return_value + + # Mock methods for the Superfluid connector + mock_superfluid.create_flow = AsyncMock(return_value="0xTransactionHash") + mock_superfluid.delete_flow = AsyncMock(return_value="0xTransactionHash") + mock_superfluid.update_flow = AsyncMock(return_value="0xTransactionHash") + mock_superfluid.manage_flow = AsyncMock(return_value="0xTransactionHash") + + # Mock get_flow to return a mock Web3FlowInfo + mock_flow_info = {"timestamp": 0, "flowRate": 0, "deposit": 0, "owedDeposit": 0} + mock_superfluid.get_flow = AsyncMock(return_value=mock_flow_info) + + yield mock_superfluid + + +@pytest.fixture +def eth_account(mock_superfluid): + private_key = b"\x01" * 32 + account = ETHAccount( + private_key, + chain=Chain.AVAX, + ) + with patch.object( + account, "get_super_token_balance", new_callable=AsyncMock + ) as mock_get_balance: + mock_get_balance.return_value = Decimal("1") + with patch.object( + account, "can_transact", new_callable=AsyncMock + ) as mock_can_transact: + mock_can_transact.return_value = True + account.superfluid_connector = mock_superfluid + yield account + + +@pytest.mark.asyncio +async def test_initialization(eth_account): + assert eth_account.superfluid_connector is not None + + +@pytest.mark.asyncio +async def test_create_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + flow = Decimal("0.00000005") + + tx_hash = await eth_account.create_flow(receiver, flow) + + assert tx_hash == "0xTransactionHash" + 
mock_superfluid.create_flow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_delete_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + + tx_hash = await eth_account.delete_flow(receiver) + + assert tx_hash == "0xTransactionHash" + mock_superfluid.delete_flow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_update_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + flow = Decimal("0.005") + + tx_hash = await eth_account.update_flow(receiver, flow) + + assert tx_hash == "0xTransactionHash" + mock_superfluid.update_flow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + + flow_info = await eth_account.get_flow(receiver) + + assert flow_info["timestamp"] == 0 + assert flow_info["flowRate"] == 0 + assert flow_info["deposit"] == 0 + assert flow_info["owedDeposit"] == 0 + + +@pytest.mark.asyncio +async def test_manage_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + flow = Decimal("0.005") + + tx_hash = await eth_account.manage_flow(receiver, flow, FlowUpdate.INCREASE) + + assert tx_hash == "0xTransactionHash" + mock_superfluid.manage_flow.assert_awaited_once() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 85f274e6..4ceb5a3f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,6 @@ +import base64 import datetime +from unittest.mock import MagicMock import pytest as pytest from aleph_message.models import ( @@ -11,14 +13,20 @@ ProgramMessage, StoreMessage, ) -from aleph_message.models.execution.environment import MachineResources from aleph_message.models.execution.volume import ( EphemeralVolume, ImmutableVolume, PersistentVolume, ) -from aleph.sdk.utils import enum_as_str, get_message_type_value, parse_volume +from aleph.sdk.types import SEVInfo +from aleph.sdk.utils import ( + calculate_firmware_hash, + 
compute_confidential_measure, + enum_as_str, + get_message_type_value, + parse_volume, +) def test_get_message_type_value(): @@ -107,15 +115,16 @@ def test_enum_as_str(): ( MessageType.aggregate, { + "address": "0x1", "content": { - "Hello": MachineResources( - vcpus=1, - memory=1024, - seconds=1, - ) + "Hello": { + "vcpus": 1, + "memory": 1024, + "seconds": 1, + "published_ports": None, + }, }, "key": "test", - "address": "0x1", "time": 1.0, }, ), @@ -132,7 +141,7 @@ async def test_prepare_aleph_message( channel="TEST", ) - assert message.content.dict() == content + assert message.content.model_dump() == content def test_parse_immutable_volume(): @@ -150,6 +159,7 @@ def test_parse_immutable_volume(): def test_parse_ephemeral_volume(): volume_dict = { "comment": "Dummy hash", + "mount": "/opt/data", "ephemeral": True, "size_mib": 1, } @@ -161,6 +171,8 @@ def test_parse_ephemeral_volume(): def test_parse_persistent_volume(): volume_dict = { + "comment": "Dummy hash", + "mount": "/opt/data", "parent": { "ref": "QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW", "use_latest": True, @@ -174,3 +186,56 @@ def test_parse_persistent_volume(): volume = parse_volume(volume) assert volume assert isinstance(volume, PersistentVolume) + + +def test_calculate_firmware_hash(): + mock_path = MagicMock( + read_bytes=MagicMock(return_value=b"abc"), + ) + + assert ( + calculate_firmware_hash(mock_path) + == "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ) + + +def test_compute_confidential_measure(): + """Verify that we properly calculate the measurement we use agains the server + + Validated against the sevctl command: + $ RUST_LOG=trace sevctl measurement build --api-major 01 --api-minor 55 --build-id 24 --policy 1 + --tik ~/pycharm-aleph-sdk-python/decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca_tik.bin + --firmware /usr/share/ovmf/OVMF.fd --nonce URQNqJAqh/2ep4drjx/XvA + + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] firmware + table 
len=4194304 sha256: d06471f485c0a61aba5a431ec136b947be56907acf6ed96afb11788ae4525aeb + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] --tik base64: npOTEc4mtRGfXfB+G6EBdw== + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] --nonce base64: URQNqJAqh/2ep4drjx/XvA== + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] Raw measurement: BAE3GAEAAADQZHH0hcCmGrpaQx7BNrlHvlaQes9u2Wr7EXiK5FJa61EUDaiQKof9nqeHa48f17w= + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] Signed measurement: ls2jv10V3HVShVI/RHCo/a43WO0soLZf0huU9ZZstIw= + [2024-07-05T11:19:06Z DEBUG sevctl::measurement] Measurement + nonce: ls2jv10V3HVShVI/RHCo/a43WO0soLZf0huU9ZZstIxRFA2okCqH/Z6nh2uPH9e8 + """ + + tik = bytes.fromhex("9e939311ce26b5119f5df07e1ba10177") + assert base64.b64encode(tik) == b"npOTEc4mtRGfXfB+G6EBdw==" + expected_hash = "d06471f485c0a61aba5a431ec136b947be56907acf6ed96afb11788ae4525aeb" + nonce = base64.b64decode("URQNqJAqh/2ep4drjx/XvA==") + sev_info = SEVInfo.model_validate( + { + "enabled": True, + "api_major": 1, + "api_minor": 55, + "build_id": 24, + "policy": 1, + "state": "running", + "handle": 1, + } + ) + + assert ( + base64.b64encode( + compute_confidential_measure( + sev_info, tik, expected_hash, nonce=nonce + ).digest() + ) + == b"ls2jv10V3HVShVI/RHCo/a43WO0soLZf0huU9ZZstIw=" + ) diff --git a/tests/unit/test_vm_client.py b/tests/unit/test_vm_client.py new file mode 100644 index 00000000..d9a9a36b --- /dev/null +++ b/tests/unit/test_vm_client.py @@ -0,0 +1,297 @@ +from urllib.parse import urlparse + +import aiohttp +import pytest +from aiohttp import web +from aioresponses import aioresponses +from aleph_message.models import ItemHash +from yarl import URL + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_client import VmClient + +from .aleph_vm_authentication import ( + SignedOperation, + SignedPubKeyHeader, + authenticate_jwk, + authenticate_websocket_message, + verify_signed_operation, +) + + +@pytest.mark.asyncio +async def 
test_notify_allocation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post("http://localhost/control/allocation/notify", status=200) + await vm_client.notify_allocation(vm_id=vm_id) + assert len(m.requests) == 1 + assert ("POST", URL("http://localhost/control/allocation/notify")) in m.requests + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_perform_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "reboot" + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.perform_operation(vm_id, operation) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_stop_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/stop", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.stop_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_reboot_instance(): + 
account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/reboot", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.reboot_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_erase_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/erase", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.erase_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_expire_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + with aioresponses() as m: + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/expire", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.expire_instance(vm_id) + assert status == 200 + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_get_logs(aiohttp_client): + account = 
ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + await ws.send_str("mock_log_entry") + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + return ws + + app = web.Application() + app.router.add_route( + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + logs = [] + async for log in vm_client.get_logs(vm_id): + logs.append(log) + if log == "mock_log_entry": + break + + assert logs == ["mock_log_entry"] + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_authenticate_jwk(aiohttp_client): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def test_authenticate_route(request): + address = await authenticate_jwk( + request, domain_name=urlparse(node_url).hostname + ) + assert vm_client.account.get_address() == address + return web.Response(text="ok") + + app = web.Application() + app.router.add_route( + "POST", f"/control/machine/{vm_id}/stop", test_authenticate_route + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + status_code, response_text = await vm_client.stop_instance(vm_id) + assert status_code == 200, response_text + assert response_text == "ok" + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_websocket_authentication(aiohttp_client): + account = 
ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + async def websocket_handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + first_message = await ws.receive_json() + credentials = first_message["auth"] + sender_address = await authenticate_websocket_message( + credentials, + domain_name=urlparse(node_url).hostname, + ) + + assert vm_client.account.get_address() == sender_address + await ws.send_str(sender_address) + + return ws + + app = web.Application() + app.router.add_route( + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler + ) # Update route to match the URL + + client = await aiohttp_client(app) + + node_url = str(client.make_url("")).rstrip("/") + + vm_client = VmClient( + account=account, + node_url=node_url, + session=client.session, + ) + + valid = False + + async for address in vm_client.get_logs(vm_id): + assert address == vm_client.account.get_address() + valid = True + + # this is done to ensure that the ws as runned at least once and avoid + # having silent errors + assert valid + + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_vm_client_generate_correct_authentication_headers(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + + vm_client = VmClient( + account=account, + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + + path, headers = await vm_client._generate_header(vm_id, "reboot", method="post") + signed_pubkey = SignedPubKeyHeader.model_validate_json(headers["X-SignedPubKey"]) + signed_operation = SignedOperation.model_validate_json(headers["X-SignedOperation"]) + address = verify_signed_operation(signed_operation, signed_pubkey) + + assert vm_client.account.get_address() == address diff --git a/tests/unit/test_vm_confidential_client.py 
b/tests/unit/test_vm_confidential_client.py new file mode 100644 index 00000000..6c5e01ed --- /dev/null +++ b/tests/unit/test_vm_confidential_client.py @@ -0,0 +1,220 @@ +import tempfile +from pathlib import Path +from unittest import mock +from unittest.mock import patch + +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_confidential_client import VmConfidentialClient + + +@pytest.mark.asyncio +async def test_perform_confidential_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/test" + + with aioresponses() as m: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + response_text = await vm_client.perform_confidential_operation(vm_id, operation) + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_initialize_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/initialize" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_bytes = Path(tmp_file.name).read_bytes() + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = 
VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + tmp_file_path = Path(tmp_file.name) + response_text = await vm_client.initialize( + vm_id, session=tmp_file_path, godh=tmp_file_path + ) + assert ( + response_text == '"mock_response_text"' + ) # ' ' cause by aioresponses + m.assert_called_once_with( + url, + method="POST", + data={ + "session": tmp_file_bytes, + "godh": tmp_file_bytes, + }, + json=None, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_measurement_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/measurement" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.get( + url, + status=200, + payload=dict( + { + "sev_info": { + "enabled": True, + "api_major": 0, + "api_minor": 0, + "build_id": 0, + "policy": 0, + "state": "", + "handle": 0, + }, + "launch_measure": "test_measure", + } + ), + ) + measurement = await vm_client.measurement(vm_id) + assert measurement.launch_measure == "test_measure" + m.assert_called_once_with( + url, + method="GET", + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_inject_secret_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = 
ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/inject_secret" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + test_secret = "test_secret" + packet_header = "test_packet_header" + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packet_header=packet_header + ) + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + json={ + "secret": test_secret, + "packet_header": packet_header, + }, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_create_session_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + certificate_prefix = ( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe/vm" + ) + platform_certificate_path = Path("/") + policy = 1 + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + _ = await vm_client.create_session( + certificate_prefix, platform_certificate_path, policy + ) + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "session", + "--name", + certificate_prefix, + str(platform_certificate_path), + str(policy), + ], + check=True, + ) diff 
--git a/tests/unit/test_wallet_ethereum.py b/tests/unit/test_wallet_ethereum.py index 0f798c9d..f7ca2157 100644 --- a/tests/unit/test_wallet_ethereum.py +++ b/tests/unit/test_wallet_ethereum.py @@ -23,7 +23,7 @@ async def test_ledger_eth_account(): address = account.get_address() assert address - assert type(address) is str + assert isinstance(address, str) assert len(address) == 42 message = Message("ETH", account.get_address(), "SomeType", "ItemHash")