Commit

Add support for optional chat format to chat dataset builder (pytorch…
ebsmothers authored Apr 22, 2024
1 parent bfd10bc commit 52567c9
Showing 2 changed files with 24 additions and 5 deletions.
18 changes: 18 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -6,6 +6,12 @@ torchtune.datasets

 .. currentmodule:: torchtune.datasets

+
+Example datasets
+----------------
+
+torchtune supports several widely used datasets to help quickly bootstrap your fine-tuning
+
 .. autosummary::
     :toctree: generated/
     :nosignatures:
@@ -15,3 +21,15 @@ torchtune.datasets
     grammar_dataset
     samsum_dataset
     slimorca_dataset
+
+Generic dataset builders
+------------------------
+
+torchtune also supports generic dataset builders for common formats like chat models and instruct models
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    instruct_dataset
+    chat_dataset
11 changes: 6 additions & 5 deletions torchtune/datasets/_chat.py
@@ -87,7 +87,7 @@ def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:

     def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
         messages = self._convert_to_messages(sample, self.train_on_input)
-        if self.chat_format:
+        if self.chat_format is not None:
             messages = self.chat_format.format(messages)
         validate_messages(messages)
         tokens, mask = self._tokenizer.tokenize_messages(
@@ -101,10 +101,11 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[in


 def chat_dataset(
+    *,
     tokenizer: Tokenizer,
     source: str,
     conversation_style: str,
-    chat_format: str,
+    chat_format: Optional[str] = None,
     max_seq_len: int,
     train_on_input: bool = False,
     **load_dataset_kwargs: Dict[str, Any],
@@ -120,8 +121,8 @@ def chat_dataset(
             (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
         conversation_style (str): string specifying expected style of conversations in the dataset
             for automatic conversion to the llama style. Supported styles are: "sharegpt"
-        chat_format (str): name of ChatFormat class used to format the messages. See the description in
-            :class:`~torchtune.datasets.ChatDataset` for more details.
+        chat_format (Optional[str]): name of ChatFormat class used to format the messages. See the description in
+            :class:`~torchtune.datasets.ChatDataset` for more details. Default: None
         max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
         train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.
@@ -141,7 +142,7 @@ def chat_dataset(
         tokenizer=tokenizer,
         source=source,
         convert_to_messages=convert_to_messages,
-        chat_format=_get_chat_format(chat_format),
+        chat_format=_get_chat_format(chat_format) if chat_format is not None else None,
         max_seq_len=max_seq_len,
         train_on_input=train_on_input,
         **load_dataset_kwargs,
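
Note on the pattern in this change: the sketch below is a minimal, self-contained illustration of an optional chat format, not torchtune's actual code. `BracketFormat`, `get_format`, and `prepare` are hypothetical stand-ins (for a ChatFormat class, `_get_chat_format`, and the builder plus `_prepare_sample`, respectively), and the character-level truncation only stands in for token-level truncation. It shows two things that appear to motivate the diff: the bare `*` makes every parameter keyword-only, which is what allows the defaulted `chat_format` to precede the required `max_seq_len`, and the format name is resolved and applied only when it is not None.

```python
from typing import Dict, List, Optional, Type

Message = Dict[str, str]


class BracketFormat:
    """Hypothetical chat format: prefixes each message with its role."""

    @classmethod
    def format(cls, messages: List[Message]) -> List[Message]:
        return [
            {"role": m["role"], "content": f"[{m['role']}] {m['content']}"}
            for m in messages
        ]


# Hypothetical registry; plays the role of a name-to-class lookup.
_FORMATS: Dict[str, Type[BracketFormat]] = {"BracketFormat": BracketFormat}


def get_format(name: str) -> Type[BracketFormat]:
    if name not in _FORMATS:
        raise ValueError(f"Unknown chat format: {name}")
    return _FORMATS[name]


def prepare(
    messages: List[Message],
    *,  # everything after this is keyword-only
    chat_format: Optional[str] = None,
    max_seq_len: int,  # required, yet legal after a defaulted keyword-only parameter
) -> List[Message]:
    # Resolve the name only when one was given; None means "leave messages as-is".
    fmt = get_format(chat_format) if chat_format is not None else None
    if fmt is not None:
        messages = fmt.format(messages)
    # Stand-in for token-level truncation: trim each message's characters.
    return [{**m, "content": m["content"][:max_seq_len]} for m in messages]
```

Calling `prepare(msgs, max_seq_len=10)` leaves the messages unformatted, while `prepare(msgs, chat_format="BracketFormat", max_seq_len=20)` applies the format first. Passing `chat_format` or `max_seq_len` positionally raises a `TypeError`.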