Skip to content

Commit

Permalink
process supervision
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 25, 2023
1 parent dbf3352 commit 05b4143
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 24 deletions.
26 changes: 3 additions & 23 deletions process_supervision/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from zeta.structs import (
AutoregressiveWrapper,
Decoder,
Encoder,
Transformer,
ViTransformerWrapper,
)


Expand Down Expand Up @@ -49,11 +47,6 @@ class GPT4(torch.nn.Module):

def __init__(
self,
image_size=256,
patch_size=32,
encoder_dim=512,
encoder_depth=6,
encoder_heads=8,
num_tokens=20000,
max_seq_len=1024,
decoder_dim=512,
Expand All @@ -69,16 +62,6 @@ def __init__(
qk_norm=True,
):
super(GPT4, self).__init__()

# vit architecture
self.encoder = ViTransformerWrapper(
image_size=image_size,
patch_size=patch_size,
attn_layers=Encoder(
dim=encoder_dim, depth=encoder_depth, heads=encoder_heads
),
)

# palm model architecture
self.decoder = Transformer(
num_tokens=num_tokens,
Expand All @@ -88,7 +71,6 @@ def __init__(
dim=decoder_dim,
depth=decoder_depth,
heads=decoder_heads,
cross_attend=cross_attend,
alibi_pos_bias=alibi_pos_bias,
alibi_num_heads=alibi_num_heads,
rotary_xpos=rotary_xpos,
Expand All @@ -101,21 +83,19 @@ def __init__(
# autoregressive wrapper to enable generation of tokens
self.decoder = AutoregressiveWrapper(self.decoder)

def forward(self, img: torch.Tensor, text: torch.Tensor):
def forward(self, text: torch.Tensor):
"""Forward pass of the model."""
try:
encoded = self.encoder(img, return_embeddings=True)
return self.decoder(text, context=encoded)
return self.decoder(text)
except Exception as error:
print(f"Failed in forward method: {error}")
raise


# Usage with random inputs
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))

# Initiliaze the model
model = GPT4()
output = model(img, text)
output = model(text)
print(output)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6"
]
packages = [
{ include = "process_supervision" },
{ include = "process_supervision/**/*.py" },
]


[tool.poetry.dependencies]
python = "^3.8.1"
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
transformers
zetascale
torch
torchvision

0 comments on commit 05b4143

Please sign in to comment.