Advice for training with TPU #2158
I'm training some timm models from scratch for face recognition, and I recently signed up for TPU Research Cloud. Do you have any advice for training timm models on TPUs? Is PyTorch XLA the way to go, or is rewriting the code in JAX worth it for performance gains/compatibility?
I'm thinking of using the big_vision codebase to train ViTs on TPUs, but I would need to convert the dataset to TFDS. I'm not sure whether I can stream TFDS from Hugging Face (I'm trying to avoid the cost of storing the data on GCS).
Technically, it would probably be possible to hack something together to HTTP-stream TFDS from .tfrecord shards in a HF dataset (it's just raw data, after all). But in any case, you wouldn't want to stream from the Hub to Google Cloud for TPU training: it's too slow, and you'd be wasting the TPUs. If you want any sort of reasonable performance and reliability, you need to copy your dataset to GCS for TPU training.
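A rough back-of-envelope check makes the "you'd be wasting the TPUs" point concrete: compare the input bandwidth the training job needs with what the data source can actually sustain. This is a hypothetical sketch; the function names and all the example numbers (target throughput, average image size, stream bandwidth) are illustrative assumptions, not measurements from this thread.

```python
# Back-of-envelope check: can a data source keep the accelerators fed?
# All example numbers are hypothetical; substitute your own measurements.
# Units: MB = 1000 KB (decimal).


def required_input_mb_per_s(images_per_sec: float, avg_image_kb: float) -> float:
    """Input bandwidth (MB/s) needed to sustain a given training throughput."""
    if images_per_sec <= 0 or avg_image_kb <= 0:
        raise ValueError("throughput and image size must be positive")
    return images_per_sec * avg_image_kb / 1000.0


def max_images_per_sec(bandwidth_mb_per_s: float, avg_image_kb: float) -> float:
    """Upper bound on throughput when the data source is the bottleneck."""
    if bandwidth_mb_per_s <= 0 or avg_image_kb <= 0:
        raise ValueError("bandwidth and image size must be positive")
    return bandwidth_mb_per_s * 1000.0 / avg_image_kb


def accelerator_utilization(
    target_images_per_sec: float, bandwidth_mb_per_s: float, avg_image_kb: float
) -> float:
    """Fraction of the target throughput the data source can deliver (capped at 1.0)."""
    if target_images_per_sec <= 0:
        raise ValueError("target throughput must be positive")
    achievable = max_images_per_sec(bandwidth_mb_per_s, avg_image_kb)
    return min(1.0, achievable / target_images_per_sec)


# Hypothetical numbers for illustration only (not measured):
TARGET_IMAGES_PER_SEC = 10_000  # what the TPU slice could consume
AVG_IMAGE_KB = 100              # average encoded image size
STREAM_MB_PER_SEC = 100         # sustained bandwidth of a remote HTTP stream

if __name__ == "__main__":
    needed = required_input_mb_per_s(TARGET_IMAGES_PER_SEC, AVG_IMAGE_KB)
    util = accelerator_utilization(TARGET_IMAGES_PER_SEC, STREAM_MB_PER_SEC, AVG_IMAGE_KB)
    print(f"needed: {needed:.0f} MB/s")      # needed: 1000 MB/s
    print(f"TPU busy fraction: {util:.0%}")  # TPU busy fraction: 10%
```

With these made-up numbers the data source, not the TPUs, sets the pace, which is the situation a bucket co-located with the TPUs is meant to avoid. Plug in your own measured throughput and bandwidth before drawing conclusions.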
I did have timm working quite well with TPUs + PyTorch XLA on an alternate branch with a different API I called bits (https://github.com/huggingface/pytorch-image-models/tree/bits_and_tpu/timm/bits), and a few people were using it successfully at the time. However, I lost reliable access to TPUs and …