
kswain98/platonic-rep

 
 


The Platonic Representation Hypothesis

Paper | Project page

Minyoung Huh*, Brian Cheung*, Tongzhou Wang*, Phillip Isola*

Requirements


Developed on Python 3.11 and PyTorch 2.2.0.

You can install the remaining requirements with:

pip install -r requirements.txt

Running alignment


(1) Extracting features

First, we extract features from the models.

# Extract features from every block of all language models, average-pooled over tokens
python extract_features.py --dataset minhuh/prh --subset wit_1024 --modelset val --modality language --pool avg

# Extract last layer features of all vision models
python extract_features.py --dataset minhuh/prh --subset wit_1024 --modelset val --modality vision --pool cls

The resulting features are stored in ./results/features.
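The --pool flag controls how the per-token features of a layer are reduced to one vector per input. The NumPy sketch below illustrates the two modes used above. The helper pool_tokens is hypothetical: it assumes a [batch, tokens, dim] array with no padding tokens and, for cls, that the class token sits at position 0 as in standard ViTs. It is an illustration, not the code in extract_features.py.

```python
import numpy as np


def pool_tokens(feats: np.ndarray, mode: str) -> np.ndarray:
    """Reduce [batch, tokens, dim] features to [batch, dim].

    Hypothetical helper for illustration only.
    """
    if mode == "avg":
        # Mean over the token axis (assumes there are no padding tokens)
        return feats.mean(axis=1)
    if mode == "cls":
        # Assumes the class token is the first token, as in standard ViTs
        return feats[:, 0, :]
    raise ValueError(f"unknown pooling mode: {mode}")


rng = np.random.default_rng(0)
vit_tokens = rng.standard_normal((4, 197, 768))  # e.g. 1 CLS + 196 patch tokens

print(pool_tokens(vit_tokens, "avg").shape)  # (4, 768)
print(pool_tokens(vit_tokens, "cls").shape)  # (4, 768)
```

Both modes yield one vector per input; avg mixes information from all tokens, while cls relies on the class token having already aggregated it.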

(2) Measuring vision-language alignment

After extracting the features, you can compute the alignment scores with:

python measure_alignment.py --dataset minhuh/prh --subset wit_1024 --modelset val \
        --modality_x language --pool_x avg --modality_y vision --pool_y cls

The resulting alignment scores are stored in ./results/alignment.

>>> import numpy as np
>>> fp = './results/alignment/minhuh/prh/val/language_pool-avg_prompt-False_vision_pool-cls_prompt-False/mutual_knn_k10.npy'
>>> result = np.load(fp, allow_pickle=True).item()
>>> print(result.keys())
dict_keys(['scores', 'indices'])
>>> print(result['scores'].shape) # 12 language models x 17 vision models
(12, 17)
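Once loaded, the scores array can be summarized with plain NumPy. The sketch below assumes, as the printed shape suggests, that rows index language models and columns index vision models. The random matrix is only a stand-in for result['scores'] so that the snippet runs on its own; it is not real data.

```python
import numpy as np

# Stand-in for result["scores"]: 12 language models x 17 vision models.
# Replace this with the array loaded from the .npy file above.
rng = np.random.default_rng(0)
scores = rng.uniform(0.0, 0.3, size=(12, 17))

# Average alignment of each language model across all vision models
mean_per_language_model = scores.mean(axis=1)

# Column index of the vision model each language model aligns with best
best_vision_model = scores.argmax(axis=1)

# Language-model row indices, ordered from most to least aligned on average
ranking = np.argsort(-mean_per_language_model)

print(mean_per_language_model.shape, best_vision_model.shape)  # (12,) (12,)
```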

Scoring your own model for alignment to the Platonic Representation Hypothesis

We provide code to compute alignment scores for your own model during training or evaluation.

(1) Install the library as a pip package

pip install -e .

(2) Initialize the metric scoring function

import platonic

# setup platonic metric
platonic_metric = platonic.Alignment(
    dataset="minhuh/prh",
    subset="wit_1024",
    models=["dinov2_g", "clip_h"],
)  # optional arguments: device, dtype, save_dir (or path to your features)

# load texts
texts = platonic_metric.get_data(modality="text")

We provide some precomputed features, so you don't have to compute them yourself; the library downloads them automatically. See SUPPORTED_DATASETS in platonic.py for what is available. Note: We will add more in the upcoming weeks.

(3) Extract the features from your model

import torch

# `language_model` is your model, which must return hidden states
# (`output_hidden_states=True`); `token_inputs` is your tokenizer's output
# (input_ids, attention_mask) for the `texts` loaded in step (2)
with torch.no_grad():
    llm_output = language_model(
        input_ids=token_inputs["input_ids"],
        attention_mask=token_inputs["attention_mask"],
        output_hidden_states=True,  # return the hidden states of every layer
    )
    # [layers, batch, tokens, dim] -> [batch, layers, tokens, dim]
    feats = torch.stack(llm_output["hidden_states"]).permute(1, 0, 2, 3)

# using average pooling (only on valid tokens)
mask = token_inputs["attention_mask"].unsqueeze(-1).unsqueeze(1)
feats = (feats * mask).sum(2) / mask.sum(2)

# compute the score: a dict with one entry per model, each containing
# (scores, indices of the maximally aligned layers)
score = platonic_metric.score(feats, metric="mutual_knn", topk=10, normalize=True)
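To see why the pooling step masks padding, here is a NumPy re-creation of the stack, permute, and masked-mean lines above on dummy hidden states (np.stack and transpose stand in for torch.stack and permute; all sizes are arbitrary assumptions). The padded positions are filled with huge values, yet the pooled result still equals the mean over the real tokens only.

```python
import numpy as np

num_layers, batch, seq_len, dim = 3, 2, 5, 4
rng = np.random.default_rng(0)

# Dummy per-layer hidden states: a tuple of [batch, tokens, dim] arrays,
# shaped like llm_output["hidden_states"]
hidden_states = tuple(rng.standard_normal((batch, seq_len, dim)) for _ in range(num_layers))

# Sample 0 has 5 real tokens; sample 1 has 3 real tokens plus 2 padding tokens
attention_mask = np.array([[1, 1, 1, 1, 1],
                           [1, 1, 1, 0, 0]])

# Fill the padding positions with huge values so any leakage would be obvious
for h in hidden_states:
    h[1, 3:] = 1e6

# [layers, batch, tokens, dim] -> [batch, layers, tokens, dim]
feats = np.stack(hidden_states).transpose(1, 0, 2, 3)

# [batch, tokens] -> [batch, 1, tokens, 1] so it broadcasts over layers and dim
mask = attention_mask[:, None, :, None]
pooled = (feats * mask).sum(axis=2) / mask.sum(axis=2)  # [batch, layers, dim]

print(pooled.shape)  # (2, 3, 4)
# Sample 1 equals the plain mean over its 3 real tokens, for every layer
print(np.allclose(pooled[1], feats[1, :, :3].mean(axis=1)))  # True
```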

We provide examples for both vision and language in the examples directory. You can run them via python examples/example_language.py. The script downloads the features to your local directory if you have not computed them already.


Customization / Questions


❔ Can I add additional models?

To add your own set of models, add them to tasks.py and adjust that file accordingly. llm_models should be autoregressive models from Hugging Face, and lvm_models should be ViT models from Hugging Face or timm. Most models should work without further modification. Currently, we do not support vision architectures other than ViTs, or language models that are not autoregressive.



❔ What are the metrics that I can use?

To list all supported alignment metrics, run:

python -c 'from metrics import AlignmentMetrics; print(AlignmentMetrics.SUPPORTED_METRICS)'
['cycle_knn', 'mutual_knn', 'lcs_knn', 'cka', 'unbiased_cka', 'cknna', 'svcca', 'edit_distance_knn']

Feel free to add your own metrics in metrics.py.
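mutual_knn is the metric used in the examples above. As a rough illustration of the idea, the sketch below builds a k-nearest-neighbor set for every sample in each representation and averages the fraction of neighbors the two sets share. It is a simplified NumPy sketch, not the implementation in metrics.py, which may differ in details such as tie handling and normalization.

```python
import numpy as np


def knn_indices(feats: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k most similar rows (by inner product), excluding each row itself."""
    sims = feats @ feats.T
    np.fill_diagonal(sims, -np.inf)  # a sample is not its own neighbor
    return np.argsort(-sims, axis=1)[:, :k]


def mutual_knn(feats_a: np.ndarray, feats_b: np.ndarray, k: int = 10) -> float:
    """Mean fraction of shared k-nearest neighbors between two representations."""
    # L2-normalize rows so inner products are cosine similarities
    feats_a = feats_a / np.linalg.norm(feats_a, axis=1, keepdims=True)
    feats_b = feats_b / np.linalg.norm(feats_b, axis=1, keepdims=True)
    nn_a = knn_indices(feats_a, k)
    nn_b = knn_indices(feats_b, k)
    overlap = [len(set(a) & set(b)) / k for a, b in zip(nn_a, nn_b)]
    return float(np.mean(overlap))


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 32))
q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix

# A rotation preserves cosine similarities, so neighborhoods should match
print(mutual_knn(x, x @ q))
# Unrelated features share few neighbors, close to chance level (about k / 63)
print(mutual_knn(x, rng.standard_normal((64, 16))))
```

Because only neighbor identities matter, this score ignores the scale of the similarities and compares the local structure of the two spaces.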



❔ I want to use the metrics in my own repo. How do I use them?
Simply copy the metrics.py file into your repo and import it anywhere. The metrics expect tensors of shape [batch x feats].

import torch
import torch.nn.functional as F
from metrics import AlignmentMetrics

# random placeholder features of shape [batch x feats]; in practice, use
# features from two models computed on the same 64 inputs
feats_A = torch.randn(64, 8192)
feats_B = torch.randn(64, 8192)
# L2-normalize each feature vector
feats_A = F.normalize(feats_A, dim=-1)
feats_B = F.normalize(feats_B, dim=-1)

# measure score
score = AlignmentMetrics.measure('cknna', feats_A, feats_B, topk=10)

# alternative
score = AlignmentMetrics.cknna(feats_A, feats_B, topk=10)
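cka is another metric in the supported list. As a point of reference, here is a minimal NumPy sketch of linear CKA (centered kernel alignment with a linear kernel) on [batch x feats] inputs. It is written from the standard formula rather than taken from metrics.py, it omits the variants that unbiased_cka and cknna implement, and its values may not match that file exactly.

```python
import numpy as np


def linear_cka(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Linear centered kernel alignment between two [batch, feats] arrays."""
    # Center each feature dimension over the batch
    a = feats_a - feats_a.mean(axis=0, keepdims=True)
    b = feats_b - feats_b.mean(axis=0, keepdims=True)
    # ||B^T A||_F^2 / (||A^T A||_F * ||B^T B||_F)
    cross = np.linalg.norm(b.T @ a, ord="fro") ** 2
    norm_a = np.linalg.norm(a.T @ a, ord="fro")
    norm_b = np.linalg.norm(b.T @ b, ord="fro")
    return float(cross / (norm_a * norm_b))


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 32))
q, _ = np.linalg.qr(rng.standard_normal((32, 32)))  # random orthogonal matrix

print(round(linear_cka(x, x @ q), 6))  # 1.0: invariant to rotations
print(round(linear_cka(x, 3.0 * x), 6))  # 1.0: invariant to isotropic scaling
print(linear_cka(x, rng.standard_normal((64, 8))))  # much lower for unrelated features
```

Unlike the kNN-based metrics, linear CKA compares global similarity structure, so it can be dominated by a few high-variance directions.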


❔ I want to add my own custom features to platonic
To add custom models, add them to SUPPORTED_DATASETS in platonic.py.



❔ Download URL is down. What do I do?
If our download URL is down, please give it some time; we will try to restore it as soon as possible. In the meantime, you can compute the exact same features by running the commands in the Extracting features section above.



❔ Reporting alignment scores

Note that scores may vary with numerical precision and batch size because of hardware and algorithmic variability. When evaluating alignment trends or reporting numbers, we recommend regenerating all features with the same settings.



Citation


@inproceedings{huh2024prh,
  title={The Platonic Representation Hypothesis},
  author={Huh, Minyoung and Cheung, Brian and Wang, Tongzhou and Isola, Phillip},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
