I have been using TPUStrategy with the TensorFlow backend to train on TPUs with Keras 2, and I am now upgrading to Keras 3. Could a comprehensive example please be provided that uses the JAX backend for multi-host, multi-TPU training? Ideally it would scale to v4-256 and above.

I am primarily interested in data parallelism. So far I have used the runtime tpu-vm-tf-2.15.0-pod-pjrt, and I have no idea what this needs to be for the JAX backend. How does the user VM, which may be distinct from the TPU hosts, connect?

Currently that is all handled by code such as:

gcloud alpha compute tpus tpu-vm create $TPU_NAME \
        --zone=$TPU_ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION

import logging
import os

import tensorflow as tf

logger = logging.getLogger(__name__)


def get_tpu_strategy() -> tf.distribute.Strategy:
    """Get tensorflow distribution strategy for TPU

    Returns
    -------
    tf.distribute.Strategy
        Model distribution strategy for TPU

    """
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=os.environ["TPU_NAME"],
        zone=os.environ["TPU_ZONE"],
        project=os.environ["GCP_PROJECT"],
    )
    logger.info(f"All TPU devices: {resolver.cluster_spec().as_dict()['worker']}")
    tf.config.experimental_connect_to_cluster(resolver)
    logger.info("Connected to TPUCluster")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    logger.info(f"Number of TPU accelerators: {strategy.num_replicas_in_sync}")
    strategy.num_workers = len(
        strategy._extended._get_input_workers(options=None)._worker_device_pairs
    )
    return strategy
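
For comparison with get_tpu_strategy() above, here is a minimal sketch of what the equivalent setup might look like with the Keras 3 JAX backend. It rests on several assumptions: the same script is launched on every TPU host (as far as I understand, JAX runs one process per host rather than having a separate user VM connect to the workers), jax.distributed.initialize() can discover the other hosts on Cloud TPU without arguments, and the installed Keras version exposes keras.distribution.DataParallel. The names get_jax_data_parallel and per_host_batch_size are hypothetical helpers, not part of either library, and the input pipeline is left out.

```python
import os


def per_host_batch_size(global_batch_size: int, num_hosts: int) -> int:
    """Hypothetical helper: split a global batch evenly across hosts."""
    if num_hosts <= 0:
        raise ValueError("num_hosts must be positive")
    if global_batch_size % num_hosts != 0:
        raise ValueError("global batch size must be divisible by the number of hosts")
    return global_batch_size // num_hosts


def get_jax_data_parallel():
    """Hypothetical counterpart of get_tpu_strategy() for the JAX backend."""
    # The backend has to be selected before keras is imported,
    # which is why jax and keras are imported inside the function.
    os.environ["KERAS_BACKEND"] = "jax"
    import jax
    import keras

    # Assumption: on Cloud TPU VMs this call finds the other hosts
    # without explicit coordinator arguments.
    jax.distributed.initialize()

    # Global device list (all hosts), then replicate the model on each device.
    devices = keras.distribution.list_devices()
    distribution = keras.distribution.DataParallel(devices=devices)
    keras.distribution.set_distribution(distribution)

    print(
        f"host {jax.process_index()} of {jax.process_count()}: "
        f"{jax.local_device_count()} local / {jax.device_count()} global devices"
    )
    return distribution
```

Under these assumptions, each host would call get_jax_data_parallel() once before building the model and would then feed only its own share of every global batch. For example, per_host_batch_size(4096, 32) gives 128 examples per host.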