Variant/kraken #99

Open
wants to merge 9 commits into base: main
added readout tokenizers and fixed discrete head
andrearosasco committed Apr 22, 2024
commit ca9ab7e74099e07c88bf65bd25df3ff4ea908870
2 changes: 2 additions & 0 deletions octo/data/dataset.py
@@ -336,6 +336,7 @@ def restructure(traj):

# add timestep info
new_obs["timestep"] = tf.range(traj_len)
new_obs['next_action'] = old_obs['next_action']

# extracts `language_key` into the "task" dict
task = {}
@@ -380,6 +381,7 @@ def restructure(traj):
for filter_fcn_spec in filter_functions:
full_dataset = full_dataset.filter(ModuleSpec.instantiate(filter_fcn_spec))
full_dataset = full_dataset.traj_map(restructure, num_parallel_calls)

# tries to load from cache, otherwise computes on the fly
dataset_statistics = get_dataset_statistics(
full_dataset,
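The functional change in this file is that restructure() now carries next_action from the raw observation dict into the restructured one, next to the per-step timestep index. A minimal standalone sketch of that slice of the mapping, with made-up shapes and assuming the source dataset actually provides a next_action field:

import tensorflow as tf

def restructure_obs(old_obs, traj_len):
    # Mirrors only the touched lines of restructure(): emit the step index and
    # pass next_action through unchanged (assumes the raw dataset provides it).
    new_obs = {}
    new_obs["timestep"] = tf.range(traj_len)
    new_obs["next_action"] = old_obs["next_action"]
    return new_obs

old_obs = {"next_action": tf.zeros((5, 7))}  # 5 steps, 7-dim actions (hypothetical)
print(restructure_obs(old_obs, 5)["timestep"].numpy())  # [0 1 2 3 4]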
4 changes: 2 additions & 2 deletions octo/model/components/action_heads.py
@@ -140,8 +140,8 @@ def discrete_loss(
labels = discrete_tokenizer(ground_truth_value)
labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])

loss = -jnp.sum(logits * labels_one_hot, axis=-1)
loss = masked_mean(loss, mask)
loss = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
loss = -masked_mean(loss, mask)

# compute accuracy between predicted actions and target actions
pred_label = jnp.argmax(logits, axis=-1)
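The substance of this fix: the old loss took an inner product of raw logits with the one-hot labels, which is not a valid cross-entropy; the new version applies log_softmax first, so the masked mean is taken over log-likelihoods. A self-contained sketch of the corrected computation, with masked_mean re-implemented here for illustration (the repository's own helper may differ in detail):

import jax
import jax.numpy as jnp

def masked_mean(x, mask):
    # Mean over unmasked elements; stands in for the helper used in action_heads.py.
    mask = jnp.broadcast_to(mask, x.shape)
    return jnp.sum(x * mask) / jnp.maximum(jnp.sum(mask), 1e-5)

def discrete_action_loss(logits, labels, mask):
    # logits: (..., vocab_size), labels: integer token ids, mask: 1 where valid.
    labels_one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    # log_softmax turns raw logits into log-probabilities before the inner
    # product, which is the essence of the corrected discrete_loss.
    log_likelihood = jnp.sum(jax.nn.log_softmax(logits, axis=-1) * labels_one_hot, axis=-1)
    return -masked_mean(log_likelihood, mask)

With raw logits, the old expression rewards making all logits large rather than making the correct bin relatively more likely; taking the log-softmax first removes that degenerate direction.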
2 changes: 2 additions & 0 deletions octo/model/octo_model.py
@@ -225,6 +225,8 @@ def load_pretrained(
tf.io.gfile.join(checkpoint_path, "config.json"), "r"
) as f:
config = json.load(f)
if 'readouts' in config['model']:
config['model']['readout_tokenizers'] = config['model'].pop('readouts')

# load example batch
with tf.io.gfile.GFile(
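load_pretrained now translates the legacy readouts config key into readout_tokenizers, so checkpoints saved before the rename keep loading. The same pattern on a toy config dict (values made up):

config = {"model": {"readouts": {"action": 1}, "token_embedding_size": 384}}

# Backward-compatibility rename, as in load_pretrained: old checkpoints store
# "readouts", the updated OctoModule expects "readout_tokenizers".
if "readouts" in config["model"]:
    config["model"]["readout_tokenizers"] = config["model"].pop("readouts")

print(config["model"]["readout_tokenizers"])  # {'action': 1}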
94 changes: 62 additions & 32 deletions octo/model/octo_module.py
@@ -77,7 +77,7 @@ class OctoTransformer(nn.Module):

observation_tokenizers: Dict[str, nn.Module]
task_tokenizers: Dict[str, nn.Module]
readouts: Dict[str, int]
readout_tokenizers: Dict[str, int | nn.Module]
transformer_kwargs: Dict
token_embedding_size: int
max_horizon: int
@@ -88,7 +88,7 @@ def __call__(
observations: Data,
tasks: Data,
pad_mask: jax.Array,
readouts: Optional[Sequence[str]] = None,
readout_tokenizers: Optional[Sequence[str]] = None,
train: bool = False,
verbose: bool = False,
) -> Dict[str, TokenGroup]:
@@ -110,15 +110,15 @@ def __call__(

Note: Horizon can be anything <= max_horizon.
"""
if readouts is None:
readouts = list(self.readouts.keys())
if readout_tokenizers is None:
readout_tokenizers = list(self.readout_tokenizers.keys())

#
# Check that all inputs are valid
#

assert set(readouts).issubset(
set(self.readouts.keys())
assert set(readout_tokenizers).issubset(
set(self.readout_tokenizers.keys())
), "readouts must be specified in the model config"

batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
@@ -213,32 +213,58 @@ def __call__(
# Finally, add the readout tokens
#

for readout_name in readouts:
group_name = f"readout_{readout_name}"
# Readouts do not correspond to any inputs, just positional embeddings
n_tokens_for_readout = self.readouts[readout_name]
readout_tokens = jnp.zeros(
(batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
)

# Add positional embedding
readout_tokens += self._create_positional_embedding(
group_name, readout_tokens
)
readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
readout_attention_rules = {
"task_*": AttentionRule.CAUSAL,
"obs_*": AttentionRule.CAUSAL,
group_name: AttentionRule.CAUSAL,
} # Attend to tasks, all previous observations, and *only its own readout*
for name, tok in self.readout_tokenizers.items():
group_name = f"readout_{name}"
if isinstance(tok, nn.Module):
tokenizer_output: TokenGroup = tok(observations, tasks, train=train)
if tokenizer_output is None:
logging.warning(f"Skipping observation tokenizer: {group_name}")
continue

obs_tokens = nn.Dense(
self.token_embedding_size, name=f"{group_name}_projection"
)(tokenizer_output.tokens)
# obs_tokens shape is (batch, horizon, n_tokens, token_embedding_size)

# Add positional embedding
obs_tokens += self._create_positional_embedding(group_name, obs_tokens)

# Update mask to account for which timesteps are padding
obs_pad_mask = jnp.logical_and(pad_mask[:, :, None], tokenizer_output.mask)

all_timestep_groups.append(
TimestepGroup(
tokens=obs_tokens,
mask=obs_pad_mask,
name=group_name,
attention_rules=observation_attention_rules,
)
)
elif isinstance(tok, int):
# Readouts do not correspond to any inputs, just positional embeddings
n_tokens_for_readout = self.readout_tokenizers[name]
readout_tokens = jnp.zeros(
(batch_size, horizon, n_tokens_for_readout, self.token_embedding_size)
)

all_timestep_groups.append(
TimestepGroup(
tokens=readout_tokens,
mask=readout_mask,
name=group_name,
attention_rules=readout_attention_rules,
# Add positional embedding
readout_tokens += self._create_positional_embedding(
group_name, readout_tokens
)
readout_mask = jnp.ones((batch_size, horizon, n_tokens_for_readout))
readout_attention_rules = {
"task_*": AttentionRule.CAUSAL,
"obs_*": AttentionRule.CAUSAL,
group_name: AttentionRule.CAUSAL,
            } # Attend to tasks, all previous observations, and *only its own readout*

all_timestep_groups.append(
TimestepGroup(
tokens=readout_tokens,
mask=readout_mask,
name=group_name,
attention_rules=readout_attention_rules,
)
)

# Run the transformer!
@@ -341,7 +367,7 @@ def create(
observation_tokenizers: Dict[str, ModuleSpec],
task_tokenizers: Dict[str, ModuleSpec],
heads: Dict[str, ModuleSpec],
readouts: Dict[str, int],
readout_tokenizers: Dict[str, int | ModuleSpec],
transformer_kwargs: Dict,
token_embedding_size: int,
max_horizon: int,
@@ -372,13 +398,17 @@ def create(
task_tokenizer_defs = {
k: ModuleSpec.instantiate(spec)() for k, spec in task_tokenizers.items()
}
readout_tokenizer_defs = {
k: ModuleSpec.instantiate(spec)() if isinstance(spec, dict) else spec
for k, spec in readout_tokenizers.items()
}

head_defs = {k: ModuleSpec.instantiate(spec)() for k, spec in heads.items()}

model_def = OctoTransformer(
observation_tokenizers=observation_tokenizer_defs,
task_tokenizers=task_tokenizer_defs,
readouts=readouts,
readout_tokenizers=readout_tokenizer_defs,
token_embedding_size=token_embedding_size,
max_horizon=max_horizon,
transformer_kwargs=transformer_kwargs,
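With this change, readout_tokenizers accepts two kinds of values: a plain int keeps the old behaviour (that many input-free readout tokens built purely from positional embeddings), while an nn.Module produces readout tokens from the observations and tasks and is projected and masked like an observation tokenizer. A sketch of what a mixed config might look like; the tokenizer class, its kwargs, and the import paths are assumptions for illustration, not part of this diff:

from octo.model.components.tokenizers import LowdimObsTokenizer  # example tokenizer (assumed)
from octo.utils.spec import ModuleSpec  # import path assumed

readout_tokenizers = {
    # old-style readout: one learned positional-embedding token for the action head
    "action": 1,
    # new-style readout: a module that tokenizes part of the observation
    "proprio": ModuleSpec.create(LowdimObsTokenizer, obs_keys=["proprio"]),
}

# Mirrors readout_tokenizer_defs in OctoModule.create: dict-valued ModuleSpecs are
# instantiated into nn.Modules, ints pass through unchanged.
readout_tokenizer_defs = {
    k: ModuleSpec.instantiate(spec)() if isinstance(spec, dict) else spec
    for k, spec in readout_tokenizers.items()
}

Only the int branch reproduces the previous readouts semantics exactly; module-valued entries go through the same projection, padding-mask, and attention-rule path as observation tokens in __call__.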