This repository contains code for the paper "Better Estimation of the KL Divergence Between Language Models"

rycolab/kl-rb

Better Estimation of the KL Divergence Between Language Models

This repository contains code for the Rao--Blackwellized Monte Carlo estimator of the KL divergence between two language models.
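For reference, the two estimators compared in this repository can be sketched as follows. The notation is our own shorthand and does not appear in the code: p and q are the two language models, y^(n) is the n-th of N sequences sampled from p, and T_n is its length.

```latex
\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{y \sim p}\left[ \log \frac{p(y)}{q(y)} \right]

\widehat{\mathrm{KL}}_{\mathrm{MC}}
  = \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n}
    \log \frac{p\big(y^{(n)}_t \mid y^{(n)}_{<t}\big)}{q\big(y^{(n)}_t \mid y^{(n)}_{<t}\big)}

\widehat{\mathrm{KL}}_{\mathrm{RB}}
  = \frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n}
    \mathrm{KL}\Big( p\big(\cdot \mid y^{(n)}_{<t}\big) \,\Big\|\, q\big(\cdot \mid y^{(n)}_{<t}\big) \Big)
```

The MC estimator only looks at the token that was actually sampled at each step, while the RB estimator replaces that single log-ratio with its expectation over the whole vocabulary given the same prefix.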

The code is based on the trl library. To run the code, replace the corresponding files in the trl library with the ones in this repository.

Our Estimator

Our RB estimator is implemented in the compute_kl function in trl/trainer/utils.py and takes only a few lines of code.

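import torch

def compute_kl(new_logprobs, ref_logprobs, logits_p=None, logits_q=None):
    # RB estimator: exact per-position KL, summed over the full vocabulary.
    if logits_p is not None:
        logp = torch.log_softmax(logits_p, dim=-1)
        logq = torch.log_softmax(logits_q, dim=-1)

        return torch.sum(torch.exp(logp) * (logp - logq), dim=-1)
    # MC estimator: log-ratio of the sampled tokens only.
    return new_logprobs - ref_logprobs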

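Note that when logits_p and logits_q are provided, the function computes the RB estimator: at each position it takes the exact KL divergence between the two next-token distributions, summing over the whole vocabulary. Otherwise, it falls back to the MC estimator, which is the difference between the log probabilities of the sampled tokens under the two models.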
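The difference between the two branches can be illustrated with a small, standard-library-only sketch. Everything in it is hypothetical and not part of this repository: the toy bigram tables P and Q stand in for two language models, and the helper names (step_kl, exact_kl, sample_estimates, compare) are ours. Under these assumptions, both estimators should average to the exact sequence-level KL, while the RB estimate varies much less from sample to sample.

```python
import itertools
import math
import random
import statistics

VOCAB = [0, 1, 2]
SEQ_LEN = 3

# Hypothetical toy models: row 0 is the start-of-sequence distribution,
# row k + 1 is the next-token distribution after token k.
P = [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3], [0.3, 0.3, 0.4], [0.5, 0.25, 0.25]]
Q = [[0.4, 0.4, 0.2], [0.2, 0.5, 0.3], [0.25, 0.25, 0.5], [0.3, 0.3, 0.4]]


def step_kl(p, q):
    """Exact KL between two next-token distributions (sum over the vocabulary)."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))


def exact_kl():
    """Sequence-level KL(P || Q), computed by enumerating every sequence."""
    total = 0.0
    for seq in itertools.product(VOCAB, repeat=SEQ_LEN):
        logp = logq = 0.0
        prev = -1
        for tok in seq:
            logp += math.log(P[prev + 1][tok])
            logq += math.log(Q[prev + 1][tok])
            prev = tok
        total += math.exp(logp) * (logp - logq)
    return total


def sample_estimates(rng):
    """Draw one sequence from P and return its (MC, RB) estimates."""
    mc = rb = 0.0
    prev = -1
    for _ in range(SEQ_LEN):
        p, q = P[prev + 1], Q[prev + 1]
        rb += step_kl(p, q)  # RB: average over all possible next tokens
        tok = rng.choices(VOCAB, weights=p)[0]
        mc += math.log(p[tok]) - math.log(q[tok])  # MC: sampled token only
        prev = tok
    return mc, rb


def compare(n_samples=20000, seed=0):
    """Return mean and variance of both estimators over n_samples sequences."""
    rng = random.Random(seed)
    mc_vals, rb_vals = zip(*(sample_estimates(rng) for _ in range(n_samples)))
    return {
        "mc_mean": statistics.fmean(mc_vals),
        "mc_var": statistics.pvariance(mc_vals),
        "rb_mean": statistics.fmean(rb_vals),
        "rb_var": statistics.pvariance(rb_vals),
    }


if __name__ == "__main__":
    print("exact KL:", exact_kl())
    print(compare())
```

In this sketch the first step contributes no variance at all to the RB estimate, because the start-of-sequence distribution is the same for every sample; the remaining variance comes only from which prefix was sampled.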

Integration with RLOO

We can use this estimator either to evaluate RLHF-trained models or inside the RL loop itself. We modify the RLOO trainer in the trl library to use the RB estimator. The changes consist of the modified trainer in trainer/rloo_trainer.py and a new config in trainer/rloo_config.py, which adds stepwise to the RLOO trainer config.

We further provide an example of how to run RLHF with this modified version in rloo_sentiment.py. An example command to run the script is:

python rloo_sentiment.py \
	--init_kl_coef 0.05 \
	--method RB \
	--seed 0 \
	--filepath /data \
	--filename imdb
  • --init_kl_coef argument specifies the initial KL coefficient
  • --method argument specifies the method to use, either MC for Monte Carlo or RB for Rao--Blackwellized
  • --seed argument specifies the random seed for reproducibility
  • --filepath argument specifies the path to the dataset
  • --filename argument specifies the name of the dataset
