This is the official repo for the paper: Concept Bottleneck Large Language Models.
- In this work, we propose the Concept Bottleneck Large Language Model (CB-LLM), the first framework for building inherently interpretable Large Language Models (LLMs) that works on both text generation and text classification tasks. CB-LLM extends and generalizes our earlier work on text classification, Crafting Large Language Models for Enhanced Interpretability, offering both interpretability and controllability in text generation (see the sketch after the list below).
- This repo contains two parts:
- CB-LLM (classification): Transforming pre-trained LLMs into interpretable LLMs for text classification
- CB-LLM (generation): Transforming pre-trained LLMs into interpretable LLMs for text generation
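For intuition, here is a minimal sketch of the classification-side architecture, assuming mean pooling and hypothetical class and variable names (not the repo's actual modules): a pre-trained backbone encodes the text, a Concept Bottleneck Layer (CBL) maps the embedding to interpretable concept activations, and a final linear layer maps concepts to class logits.

```python
# Hypothetical sketch of a CB-LLM classifier: backbone -> CBL -> linear head.
# Names, pooling, and the concept count are illustrative assumptions.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class ConceptBottleneckClassifier(nn.Module):
    def __init__(self, backbone_name, n_concepts, n_classes):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)
        hidden = self.backbone.config.hidden_size
        self.cbl = nn.Linear(hidden, n_concepts)    # one unit per concept
        self.final = nn.Linear(n_concepts, n_classes)

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.last_hidden_state.mean(dim=1)  # simple mean pooling
        concepts = torch.relu(self.cbl(pooled))     # non-negative concept scores
        return self.final(concepts), concepts

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = ConceptBottleneckClassifier("roberta-base", n_concepts=100, n_classes=2)
batch = tokenizer(["a touching and heartfelt film"], return_tensors="pt")
logits, concepts = model(batch["input_ids"], batch["attention_mask"])
```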
We recommend CUDA 12.1, Python 3.10, and PyTorch 2.2. Go into the folder for the classification case:

```sh
cd classification
```

Install the required packages:

```sh
pip install -r requirements.txt
```
We also provide finetuned CB-LLMs, allowing you to skip the training process. Download the checkpoints from Hugging Face:

```sh
git lfs install
git clone https://huggingface.co/cesun/cbllm-classification temp_repo
mv temp_repo/mpnet_acs .
rm -rf temp_repo
```
To generate the concept scores with our Automatic Concept Scoring (ACS) strategy, run

```sh
python get_concept_labels.py
```

This will generate the concept scores for the SST2 dataset using our predefined concept set, and store the scores under `mpnet_acs/SetFit_sst2/`. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
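ACS scores each training sentence against every concept with a sentence-embedding model (the `mpnet_acs/` directory name suggests an mpnet encoder; the model choice and toy concept set below are illustrative assumptions, not the repo's exact procedure):

```python
# Illustrative ACS-style scoring: cosine similarity between sentence embeddings
# and concept embeddings from a sentence encoder (model choice is an assumption).
from sentence_transformers import SentenceTransformer, util

encoder = SentenceTransformer("all-mpnet-base-v2")
concepts = ["positive sentiment", "negative sentiment"]   # toy concept set
sentences = ["a touching and heartfelt film", "a dull, lifeless bore"]

concept_emb = encoder.encode(concepts, convert_to_tensor=True)
sentence_emb = encoder.encode(sentences, convert_to_tensor=True)

# scores[i, j] = similarity of sentence i to concept j; these play the role
# of the concept labels that supervise the CBL.
scores = util.cos_sim(sentence_emb, concept_emb)
print(scores)
```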
To train the CBL, run

```sh
python train_CBL.py --automatic_concept_correction
```

This will train the CBL with Automatic Concept Correction for the SST2 dataset, and store the model under `mpnet_acs/SetFit_sst2/roberta_cbm/`. To disable Automatic Concept Correction, remove the argument. Set the argument `--backbone gpt2` to switch the backbone from `roberta` to `gpt2`. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
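Our reading of Automatic Concept Correction is that it cleans the ACS scores using the ground-truth labels before CBL training; the rule sketched below (zero out concepts tied to a different class than the sample's label) is an assumption, so treat it as an illustration rather than the repo's exact procedure:

```python
# Hedged sketch of Automatic Concept Correction (ACC): suppress ACS scores for
# concepts whose associated class disagrees with the sample's label.
import torch

def automatic_concept_correction(scores, labels, concept_to_class):
    """scores: (n_samples, n_concepts) ACS similarities
    labels: (n_samples,) class labels
    concept_to_class: (n_concepts,) class each concept belongs to"""
    match = concept_to_class.unsqueeze(0) == labels.unsqueeze(1)  # (n, c)
    return torch.where(match, scores.clamp(min=0.0), torch.zeros_like(scores))

scores = torch.tensor([[0.7, -0.2], [0.1, 0.6]])
labels = torch.tensor([0, 1])              # sample 0 is negative, 1 is positive
concept_to_class = torch.tensor([0, 1])    # concept 0 negative, 1 positive
print(automatic_concept_correction(scores, labels, concept_to_class))
```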
To train the final predictor, run

```sh
python train_FL.py --cbl_path mpnet_acs/SetFit_sst2/roberta_cbm/cbl_acc.pt
```

This will train the linear predictor of the CBL for the SST2 dataset, and store the linear layer in the same directory. Please change the argument `--cbl_path` accordingly if using other settings; for example, the checkpoint trained without Automatic Concept Correction is saved as `cbl.pt`.
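The final predictor is a linear probe over the frozen CBL activations. A generic sketch of this step (the data here is random, and the L1 penalty matching the `--sparse` option used later is our assumption, not necessarily the repo's solver):

```python
# Minimal sketch of training the final linear layer on frozen CBL outputs:
# logistic regression over concepts, so each weight ties a concept to a class.
import torch
import torch.nn as nn

n_concepts, n_classes = 4, 2
final = nn.Linear(n_concepts, n_classes)
opt = torch.optim.Adam(final.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

concepts = torch.rand(8, n_concepts)           # stand-in for CBL activations
labels = torch.randint(0, n_classes, (8,))

for _ in range(100):
    opt.zero_grad()
    loss = loss_fn(final(concepts), labels)
    # an L1 penalty encourages a sparse, more readable concept-to-class map
    loss = loss + 1e-4 * final.weight.abs().sum()
    loss.backward()
    opt.step()
```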
To train the baseline standard black-box model, run

```sh
python finetune_black_box.py
```

This will train the black-box (non-interpretable) model for the SST2 dataset, and store the model under `baseline_models/roberta/`. Set the argument `--backbone gpt2` to switch the backbone from `roberta` to `gpt2`. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
To test the accuracy of the CB-LLM, run

```sh
python test_CBLLM.py --cbl_path mpnet_acs/SetFit_sst2/roberta_cbm/cbl_acc.pt
```

Please change the argument `--cbl_path` accordingly if using other settings; for example, the checkpoint trained without Automatic Concept Correction is saved as `cbl.pt`. Add the `--sparse` argument to test with the sparse final layer.
To visualize the neurons in CB-LLM (task 1 in our paper), run

```sh
python print_concept_activations.py --cbl_path mpnet_acs/SetFit_sst2/roberta_cbm/cbl_acc.pt
```

This will generate the 5 most related samples for each neuron explanation. Please change the argument `--cbl_path` accordingly if using other settings.
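A simple way to implement this kind of neuron visualization, assuming the concept activations are already computed (random stand-ins here), is to rank all samples by one neuron's activation and keep the top five:

```python
# Illustrative neuron probing (not the repo's script): rank samples by how
# strongly they activate one concept neuron, then show the five highest.
import torch

activations = torch.rand(1000, 4)  # stand-in: (n_samples, n_concepts) CBL outputs
samples = [f"sample {i}" for i in range(1000)]

neuron = 2                          # pick one concept neuron
top = torch.topk(activations[:, neuron], k=5)
for score, idx in zip(top.values, top.indices):
    print(f"{score:.3f}  {samples[idx]}")
```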
To get the explanations provided by CB-LLM (task 2 in our paper), run

```sh
python print_concept_contributions.py --cbl_path mpnet_acs/SetFit_sst2/roberta_cbm/cbl_acc.pt
```

This will generate 5 explanations for each sample in the dataset. Please change the argument `--cbl_path` accordingly if using other settings.
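A common recipe for per-sample explanations in concept bottleneck models, and our assumption of what this script computes, is to score each concept by its activation times its final-layer weight toward the predicted class:

```python
# Hedged sketch: concept contributions as activation * final-layer weight for
# the predicted class, keeping the top concepts. All values are illustrative.
import torch

concept_names = ["joyful tone", "dull pacing", "strong acting", "weak plot"]
activations = torch.tensor([0.9, 0.1, 0.6, 0.0])   # CBL output for one sample
final_weight = torch.rand(2, 4)                    # (n_classes, n_concepts)

pred = (final_weight @ activations).argmax()       # predicted class
contributions = activations * final_weight[pred]
top = torch.topk(contributions, k=3)
for score, idx in zip(top.values, top.indices):
    print(f"{concept_names[idx]}: {score:.3f}")
```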
To test the accuracy of the baseline standard black-box model, run

```sh
python test_black_box.py --model_path baseline_models/roberta/backbone_finetuned_sst2.pt
```

Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset, and change the argument `--model_path` accordingly if using other settings.
Accuracy ↑ | SST2 | YelpP | AGnews | DBpedia
---|---|---|---|---
**Ours:** | | | |
CB-LLM | 0.9012 | 0.9312 | 0.9009 | 0.9831
CB-LLM w/ ACC | 0.9407 | 0.9806 | 0.9453 | 0.9928
**Baselines:** | | | |
TBM&C³M | 0.9270 | 0.9534 | 0.8972 | 0.9843
RoBERTa-base fine-tuned (black-box) | 0.9462 | 0.9778 | 0.9508 | 0.9917
We recommend CUDA 12.1, Python 3.10, and PyTorch 2.2. Go into the folder for the generation case:

```sh
cd generation
```

Install the required packages:

```sh
pip install -r requirements.txt
```
We also provide finetuned CB-LLMs, allowing you to skip the training process. Download the checkpoints from Hugging Face:

```sh
git lfs install
git clone https://huggingface.co/cesun/cbllm-generation temp_repo
mv temp_repo/from_pretained_llama3_lora_cbm .
rm -rf temp_repo
```
To train the CB-LLM for text generation, run

```sh
python train_CBLLM.py
```

This will train the CB-LLM (Llama3 finetuned with LoRA plus a CBL) on the SST2 dataset, using the class labels (negative or positive) as concepts, and store the model under `from_pretained_llama3_lora_cbm/SetFit_sst2/`. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
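For intuition, here is a minimal sketch of this setup: a LoRA-wrapped causal LM whose hidden states also feed a small concept head. GPT-2 stands in for Llama3 to keep the example lightweight, and the LoRA config and head are illustrative assumptions, not the repo's training code:

```python
# Hedged sketch: LoRA-finetuned causal LM + concept head on the hidden states.
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")   # stand-in for Llama3
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8))
cbl = nn.Linear(base.config.hidden_size, 2)           # 2 concepts: neg / pos

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("the movie was wonderful", return_tensors="pt").input_ids
hidden = model(input_ids=ids, output_hidden_states=True).hidden_states[-1]
concepts = torch.sigmoid(cbl(hidden[:, -1]))          # concept activations
print(concepts)
```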
To test the concept detection (concept accuracy) of the CB-LLM, run

```sh
python test_concepts.py
```

Please rename the desired checkpoint of the peft model and the CBL to `llama3` and `cbl.pt`, as the script recognizes these file names. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
To test the steerability of the CB-LLM, first train the RoBERTa classifier used to determine whether the generated text belongs to the desired class:

```sh
python train_classifier.py
```

Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
After training the classifier for the corresponding dataset, evaluate the steerability by running

```sh
python test_steerability.py
```

Please rename the desired checkpoint of the peft model and the CBL to `llama3` and `cbl.pt`, as the script recognizes these file names.
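The metric itself is straightforward: steered generations are fed to the trained classifier, and steerability is the fraction judged as the target class. A toy sketch, with a stock sentiment pipeline standing in for the trained RoBERTa classifier and hand-written strings standing in for CB-LLM generations:

```python
# Hedged sketch of a steerability check (illustrative, not test_steerability.py):
# score how often steered generations are judged as the target class.
from transformers import pipeline

clf = pipeline("sentiment-analysis")  # stand-in for the trained classifier
steered_outputs = [
    "what a wonderful, heartwarming film",   # pretend these came from the
    "an absolute joy from start to finish",  # CB-LLM steered toward "positive"
]
preds = clf(steered_outputs)
steerability = sum(p["label"] == "POSITIVE" for p in preds) / len(preds)
print(steerability)
```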
To test the perplexity using Llama3-8B, run

```sh
python test_perplexity.py
```

Please rename the desired checkpoint of the peft model and the CBL to `llama3` and `cbl.pt`, as the script recognizes these file names. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
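Perplexity here is the standard exponentiated mean token negative log-likelihood under a judge LM. A generic sketch of the computation, using GPT-2 as a lightweight stand-in for Llama3-8B (which is gated and memory-hungry):

```python
# Generic perplexity computation: exp of the mean token negative log-likelihood
# under a judge language model (GPT-2 here as a lightweight stand-in).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
judge = AutoModelForCausalLM.from_pretrained("gpt2").eval()

text = "the generated review praises the food and the friendly service"
ids = tok(text, return_tensors="pt").input_ids
with torch.no_grad():
    # labels=ids makes the model return the mean cross-entropy over tokens
    loss = judge(ids, labels=ids).loss
print(torch.exp(loss).item())   # perplexity
```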
To inspect the learned weights, run

```sh
python test_weight.py
```

Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset. Paste the printed results into SankeyMATIC to visualize the weights.
To steer the generation, run

```sh
python test_generation.py
```

By changing the activation value of the corresponding neuron in Line 48 of the script, the generation will contain the desired concepts. Set the argument `--dataset yelp_polarity`, `--dataset ag_news`, or `--dataset dbpedia_14` to switch the dataset.
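The steering mechanism amounts to overriding a concept neuron's activation before tokens are produced. A toy numeric sketch of why that biases generation, assuming the final layer maps concept activations to token logits (shapes and names are illustrative, not the repo's code):

```python
# Toy sketch of concept steering: forcing one concept activation high shifts
# the next-token logits toward tokens favored by that concept.
import torch

n_concepts, vocab_size = 2, 10
concept_to_vocab = torch.rand(vocab_size, n_concepts)  # stand-in final layer

concepts = torch.tensor([0.1, 0.2])   # activations coming out of the CBL
concepts[1] = 100.0                   # intervene: force the "positive" neuron high

logits = concept_to_vocab @ concepts  # logits now lean toward the forced concept
print(logits.argmax().item())
```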
The accuracy, steerability, and perplexity of CB-LLMs (generation). CB-LLMs perform well on accuracy (↑) and perplexity (↓) while providing higher steerability (↑).

Method | Metric | SST2 | YelpP | AGnews | DBpedia
---|---|---|---|---|---
CB-LLM (Ours) | Accuracy ↑ | 0.9638 | 0.9855 | 0.9439 | 0.9924
 | Steerability ↑ | 0.82 | 0.95 | 0.85 | 0.58
 | Perplexity ↓ | 116.22 | 13.03 | 18.25 | 37.59
CB-LLM w/o ADV training | Accuracy ↑ | 0.9676 | 0.9830 | 0.9418 | 0.9934
 | Steerability ↑ | 0.57 | 0.69 | 0.52 | 0.21
 | Perplexity ↓ | 59.19 | 12.39 | 17.93 | 35.13
Llama3 finetuned (black-box) | Accuracy ↑ | 0.9692 | 0.9851 | 0.9493 | 0.9919
 | Steerability ↑ | No | No | No | No
 | Perplexity ↓ | 84.70 | 6.62 | 12.52 | 41.50
Chung-En Sun, Tuomas Oikarinen, Berk Ustun, and Tsui-Wei Weng. "Concept Bottleneck Large Language Models". arXiv preprint, 2024.

```
@article{cbllm,
  title={Concept Bottleneck Large Language Models},
  author={Chung-En Sun and Tuomas Oikarinen and Berk Ustun and Tsui-Wei Weng},
  journal={arXiv preprint},
  year={2024}
}
```