This repository contains code for the Rao--Blackwellized Monte Carlo estimator of the KL divergence between two language models.

The code is based on the `trl` library. To run the code, replace the corresponding files in the `trl` library with the ones in this repository.

Our RB estimator is implemented in the `compute_kl` function in `trl/trainer/utils.py`, in a few lines of code:
```python
import torch

def compute_kl(new_logprobs, ref_logprobs, logits_p=None, logits_q=None):
    # RB estimator: exact per-token KL, summed over the full vocabulary.
    if logits_p is not None:
        logp = torch.log_softmax(logits_p, dim=-1)
        logq = torch.log_softmax(logits_q, dim=-1)
        return torch.sum(torch.exp(logp) * (logp - logq), dim=-1)
    # MC estimator: single-sample log-probability ratio at the sampled token.
    return new_logprobs - ref_logprobs
```
Note that if `logits_p` and `logits_q` are not `None`, the KL divergence is computed with the RB estimator, which evaluates the per-token KL exactly by summing over the vocabulary. Otherwise, it falls back to the MC estimator, the log-probability ratio at the sampled token.
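As a quick sanity check, the snippet below compares the two modes on random logits (it reproduces `compute_kl` so it runs standalone); the tensor shapes and seed here are illustrative only:

```python
import torch

# Reproduced from trl/trainer/utils.py so this example is self-contained.
def compute_kl(new_logprobs, ref_logprobs, logits_p=None, logits_q=None):
    if logits_p is not None:
        logp = torch.log_softmax(logits_p, dim=-1)
        logq = torch.log_softmax(logits_q, dim=-1)
        return torch.sum(torch.exp(logp) * (logp - logq), dim=-1)
    return new_logprobs - ref_logprobs

torch.manual_seed(0)
logits_p = torch.randn(3, 5)  # 3 positions, vocabulary of size 5
logits_q = torch.randn(3, 5)

# RB estimate: exact per-position KL (always nonnegative).
rb = compute_kl(None, None, logits_p, logits_q)

# MC estimate: log p(x) - log q(x) at one token sampled from p per position.
logp = torch.log_softmax(logits_p, dim=-1)
logq = torch.log_softmax(logits_q, dim=-1)
tokens = torch.distributions.Categorical(logits=logits_p).sample()
mc = compute_kl(logp.gather(-1, tokens[:, None]).squeeze(-1),
                logq.gather(-1, tokens[:, None]).squeeze(-1))

print(rb.shape, mc.shape)  # both are per-position estimates of shape (3,)
```

Unlike the MC estimate, which can be negative for an unlucky sample, the RB estimate is an exact KL and is therefore nonnegative at every position.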
We can use this estimator either to evaluate RLHF-trained models or inside the RL loop. We modify the `rloo` trainer in the `trl` library to use the RB estimator: the modified trainer is in `trainer/rloo_trainer.py`, and a new config in `trainer/rloo_config.py` adds `stepwise` to the RLOO trainer config.
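For orientation, the shape of the config change is roughly as follows. This is a sketch, not the actual `trl` code: the field name `stepwise` comes from the text above, while its type, default, and semantics are assumptions.

```python
from dataclasses import dataclass

# Sketch only: the real RLOO config has many more fields; we show just the
# assumed new one. That `stepwise` toggles the per-step (RB) KL estimator is
# an assumption about its semantics.
@dataclass
class RLOOConfigSketch:
    stepwise: bool = False  # assumed default: keep the original MC behavior
```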
We further provide an example of how to run RLHF with this modified version in `rloo_sentiment.py`. An example command to run the script is:

```shell
python rloo_sentiment.py \
    --init_kl_coef 0.05 \
    --method [MC, RB] \
    --seed 0 \
    --filepath /data \
    --filename imdb
```
- The `--init_kl_coef` argument specifies the initial KL coefficient.
- The `--method` argument specifies the method to use, either `MC` for Monte Carlo or `RB` for Rao--Blackwellized.
- The `--seed` argument specifies the random seed for reproducibility.
- The `--filepath` argument specifies the path to the dataset.
- The `--filename` argument specifies the name of the dataset.