diff --git a/uniswap/constants.py b/uniswap/constants.py index 60fd4c3..58b4522 100644 --- a/uniswap/constants.py +++ b/uniswap/constants.py @@ -1,3 +1,17 @@ +from typing import Set, cast +from web3.types import ( # noqa: F401 + RPCEndpoint, +) + +# look at web3/middleware/cache.py for reference +# RPC methods that will be cached inside _get_eth_simple_cache_middleware +SIMPLE_CACHE_RPC_WHITELIST = cast( + Set[RPCEndpoint], + { + "eth_chainId", + }, +) + ETH_ADDRESS = "0x0000000000000000000000000000000000000000" WETH9_ADDRESS = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2" diff --git a/uniswap/uniswap.py b/uniswap/uniswap.py index ae5ad16..8660f2d 100644 --- a/uniswap/uniswap.py +++ b/uniswap/uniswap.py @@ -23,6 +23,7 @@ from .token import ERC20Token from .exceptions import InvalidToken, InsufficientBalance from .util import ( + _get_eth_simple_cache_middleware, _str_to_addr, _addr_to_str, _validate_address, @@ -78,6 +79,7 @@ def __init__( # use_eip1559: bool = True, factory_contract_addr: str = None, router_contract_addr: str = None, + enable_caching: bool = False, ) -> None: """ :param address: The public address of the ETH wallet to use. @@ -88,6 +90,7 @@ def __init__( :param default_slippage: Default slippage for a trade, as a float (0.01 is 1%). WARNING: slippage is untested. :param factory_contract_addr: Can be optionally set to override the address of the factory contract. :param router_contract_addr: Can be optionally set to override the address of the router contract (v2 only). + :param enable_caching: Optionally enables middleware caching RPC method calls. """ self.address = _str_to_addr( address or "0x0000000000000000000000000000000000000000" @@ -115,7 +118,9 @@ def __init__( provider = os.environ["PROVIDER"] self.w3 = Web3(Web3.HTTPProvider(provider, request_kwargs={"timeout": 60})) - # Cache netid to avoid extra RPC calls + if enable_caching: + self.w3.middleware_onion.inject(_get_eth_simple_cache_middleware(), layer=0) + self.netid = int(self.w3.net.version) if self.netid in _netid_to_name: self.netname = _netid_to_name[self.netid] diff --git a/uniswap/util.py b/uniswap/util.py index 2f65ab2..6826812 100644 --- a/uniswap/util.py +++ b/uniswap/util.py @@ -2,16 +2,27 @@ import json import math import functools -from typing import Any, Generator, Sequence, Union, List, Tuple +import lru + +from typing import Any, Generator, Sequence, Union, List, Tuple, Type, Dict, cast from web3 import Web3 from web3.exceptions import NameNotFound from web3.contract import Contract +from web3.middleware.cache import construct_simple_cache_middleware +from web3.types import Middleware -from .constants import MIN_TICK, MAX_TICK, _tick_spacing +from .constants import MIN_TICK, MAX_TICK, _tick_spacing, SIMPLE_CACHE_RPC_WHITELIST from .types import AddressLike, Address +def _get_eth_simple_cache_middleware() -> Middleware: + return construct_simple_cache_middleware( + cache_class=cast(Type[Dict[Any, Any]], functools.partial(lru.LRU, 256)), + rpc_whitelist=SIMPLE_CACHE_RPC_WHITELIST, + ) + + def _str_to_addr(s: Union[AddressLike, str]) -> Address: """Idempotent""" if isinstance(s, str):