
Commit

Bitsfit (#27)
* ✨feat:add bitfit

---------

Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin authored Jul 3, 2023
1 parent 7ba40c5 commit 1ea923f
Showing 4 changed files with 52 additions and 12 deletions.
scripts/train_m3e.py (9 changes: 8 additions & 1 deletion)
@@ -18,7 +18,7 @@
 )
 from uniem.trainer import Trainer
 from uniem.types import MixedPrecisionType
-from uniem.utils import create_adamw_optimizer
+from uniem.utils import apply_bitfit, convert_to_readable_string, create_adamw_optimizer

 app = typer.Typer()

@@ -47,6 +47,7 @@ def main(
     model_name_or_path: str,
     m3e_datasets_dir: Path,
     # Model
+    model_class: Annotated[Optional[str], typer.Option(rich_help_panel='Model')] = None,
     temperature: Annotated[float, typer.Option(rich_help_panel='Model')] = 0.05,
     loss_type: Annotated[InBatchNegLossType, typer.Option(rich_help_panel='Model')] = InBatchNegLossType.softmax,
     embedding_strategy: Annotated[PoolingStrategy, typer.Option(rich_help_panel='Model')] = PoolingStrategy.last_mean,
@@ -61,6 +62,7 @@
     num_warmup_steps: Annotated[float, typer.Option(rich_help_panel='Optimizer')] = 0.05,
     # Trainer
     epochs: Annotated[int, typer.Option(rich_help_panel='Trainer')] = 3,
+    bitfit: Annotated[bool, typer.Option(rich_help_panel='Trainer')] = False,
     mixed_precision: Annotated[MixedPrecisionType, typer.Option(rich_help_panel='Trainer')] = MixedPrecisionType.no,
     gradient_accumulation_steps: Annotated[int, typer.Option(rich_help_panel='Trainer')] = 1,
     save_on_epoch_end: Annotated[bool, typer.Option(rich_help_panel='Trainer')] = False,
@@ -114,10 +116,15 @@ def main(

     model = EmbedderForPairInBatchNegTrain(
         model_name_or_path=model_name_or_path,
+        model_class=model_class,
         temperature=temperature,
         loss_type=loss_type,
         embedding_strategy=embedding_strategy,
     )
+    if bitfit:
+        apply_bitfit(model)
+    num_training_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    accelerator.print(f'Number of training parameters: {convert_to_readable_string(num_training_parameters)}')
     model.embedder.encoder.config.pad_token_id = tokenizer.pad_token_id
     model = accelerator.prepare(model)

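
Usage note (a hedged sketch, not part of the diff): with typer's default option naming, the two new parameters should surface on the training CLI as --model-class and --bitfit, e.g.

    python scripts/train_m3e.py <model_name_or_path> <m3e_datasets_dir> --bitfit

where the angle-bracket placeholders stand for the positional arguments defined above, and --model-class T5EncoderModel would pick an explicit transformers class for the encoder instead of AutoModel.
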
tests/test_utils.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+from uniem.utils import convert_to_readable_string
+
+
+def test_convert_to_readable_string():
+    assert convert_to_readable_string(123) == '123'
+    assert convert_to_readable_string(1234) == '1.2k'
+    assert convert_to_readable_string(1234567) == '1.2M'
+    assert convert_to_readable_string(1234567890) == '1.2B'
uniem/model.py (28 changes: 17 additions & 11 deletions)
@@ -1,11 +1,12 @@
+import importlib
 from enum import Enum
 from pathlib import Path
 from typing import ClassVar, Literal, Type, TypeVar, cast

 import numpy as np
 import torch
 import tqdm
-from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel  # type: ignore
+from transformers import AutoModel, AutoTokenizer, PreTrainedModel  # type: ignore

 from uniem.criteria import (
     CoSentLoss,
@@ -47,15 +48,18 @@ def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None
     return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True)


-def load_hf_pretrained_model(model_name_or_path: str) -> PreTrainedModel:
-    config = AutoConfig.from_pretrained(model_name_or_path)
-    if config.model_type == 't5':
-        from transformers import T5EncoderModel  # type: ignore
+def load_hf_pretrained_model(
+    model_name_or_path: str, model_class: str | None | Type[PreTrainedModel] | Type[AutoModel] = None
+) -> PreTrainedModel:
+    if model_class is None:
+        model_class = AutoModel
+    elif isinstance(model_class, str):
+        transformers_module = importlib.import_module('transformers')
+        model_class = getattr(transformers_module, model_class)

-        pretrained_model = T5EncoderModel.from_pretrained(model_name_or_path)
-    else:
-        pretrained_model = AutoModel.from_pretrained(model_name_or_path)
-    return pretrained_model  # type: ignore
+    model = model_class.from_pretrained(model_name_or_path)  # type: ignore
+    model = cast(PreTrainedModel, model)
+    return model


 StrategyEmbedderClsMap: dict[PoolingStrategy, Type['Embedder']] = {}
@@ -191,11 +195,12 @@ class EmbedderForPairInBatchNegTrain(EmbedderForTrain):
     def __init__(
         self,
         model_name_or_path: str,
+        model_class: str | None = None,
         temperature: float | None = None,
         loss_type: InBatchNegLossType | str = InBatchNegLossType.softmax,
         embedding_strategy: PoolingStrategy | str = PoolingStrategy.last_mean,
     ):
-        pretrained_model = load_hf_pretrained_model(model_name_or_path)
+        pretrained_model = load_hf_pretrained_model(model_name_or_path, model_class=model_class)
         embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)](pretrained_model)
         super().__init__(embedder)
         temperature = temperature or 0.05
@@ -219,12 +224,13 @@ class EmbedderForTripletInBatchNegTrain(EmbedderForTrain):
     def __init__(
         self,
         model_name_or_path: str,
+        model_class: str | None = None,
         temperature: float | None = None,
         loss_type: InBatchNegLossType | str = InBatchNegLossType.softmax,
         embedding_strategy: PoolingStrategy | str = PoolingStrategy.last_mean,
         add_swap_loss: bool = False,
     ):
-        pretrained_model = load_hf_pretrained_model(model_name_or_path)
+        pretrained_model = load_hf_pretrained_model(model_name_or_path, model_class=model_class)
         embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)](pretrained_model)
         super().__init__(embedder)
         temperature = temperature or 0.05
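
For orientation, a minimal sketch (not part of the commit) of how the reworked load_hf_pretrained_model can be called; the checkpoint names are illustrative placeholders. Passing a string routes through importlib.import_module('transformers') and getattr, so any class exported by transformers can be named; leaving model_class as None falls back to AutoModel, and the old hard-coded T5 branch becomes the explicit T5EncoderModel case:

    from uniem.model import load_hf_pretrained_model

    # Default path: AutoModel resolves the architecture from the checkpoint config.
    encoder = load_hf_pretrained_model('bert-base-chinese')

    # String path: the class name is looked up on the transformers module at runtime,
    # which covers the encoder-only T5 case the removed branch used to special-case.
    t5_encoder = load_hf_pretrained_model('t5-small', model_class='T5EncoderModel')
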
uniem/utils.py (19 changes: 19 additions & 0 deletions)
@@ -35,6 +35,14 @@ def generate_batch(data: Iterable[T], batch_size: int = 32) -> Generator[list[T]
         yield batch


+def apply_bitfit(model: torch.nn.Module):
+    for name, param in model.named_parameters():
+        if 'bias' in name:
+            param.requires_grad = True
+        else:
+            param.requires_grad = False
+
+
 def split_dataset_dict(dataset_dict: dict[str, T]) -> tuple[T, T | None]:
     if isinstance(dataset_dict, dict):
         train_dataset = dataset_dict['train']
@@ -85,3 +93,14 @@ def decorator(*args, **kwargs):
            raise

     return decorator
+
+
+def convert_to_readable_string(number: float) -> str:
+    if number >= 1e9:
+        return f'{number / 1e9:.1f}B'
+    elif number >= 1e6:
+        return f'{number / 1e6:.1f}M'
+    elif number >= 1e3:
+        return f'{number / 1e3:.1f}k'
+    else:
+        return str(number)
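
To make the BitFit change concrete, a small sketch (not part of the commit) of what apply_bitfit does: it freezes every parameter whose name does not contain 'bias', which is the BitFit recipe of training only the bias terms, and this mirrors the trainable-parameter count the training script now prints via accelerator.print. The single torch.nn.Linear module below is purely illustrative:

    import torch

    from uniem.utils import apply_bitfit, convert_to_readable_string

    # Toy stand-in for the embedder; any torch.nn.Module behaves the same way.
    model = torch.nn.Linear(768, 768)
    apply_bitfit(model)  # 'weight' is frozen, only the 'bias' vector keeps requires_grad=True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f'trainable: {convert_to_readable_string(trainable)} / {convert_to_readable_string(total)}')
    # prints: trainable: 768 / 590.6k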
