Failed to convert a convolution when the padding argument is 'valid' or default #31

Closed
kokeshing opened this issue Apr 6, 2024 · 2 comments

kokeshing commented Apr 6, 2024

First of all, thank you for developing this great project!

I found that convolution conversion fails when the padding argument is 'valid' (nn.Conv1d, nn.Conv2d) or left at its default (F.conv1d, F.conv2d).

UserWarning: Conversion exception on node 'Conv1d': The `padding` argument must be a tuple of 2 integers. Received: v
    raise Exception(f'Failed conversion: {self.original_node}')
Exception: Failed conversion: Conv1d(128, 128, kernel_size=(1,), stride=(1,), padding=valid)

UserWarning: Validation exception on node 'conv1d': Failed conversion: <built-in method conv1d of type object at 0x105101780>
Exception: Failed conversion: <built-in method conv1d of type object at 0x122109780>

ValueError: `padding` should have two elements. Received: valid.
    raise Exception(f'Failed conversion: {self.original_node}')
Exception: Failed conversion: Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), padding=valid)

UserWarning: Validation exception on node 'ErrorModel': Failed conversion: <built-in method conv2d of type object at 0x122101980>
Exception: Failed conversion: <built-in method conv2d of type object at 0x122101510>
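
The `Received: v` in the first warning suggests that the string 'valid' was indexed as if it were a sequence of integers. Below is a minimal sketch of the kind of normalization a converter could apply before building the target layer. `normalize_padding` is a hypothetical helper written for illustration; it is not part of nobuco and is not necessarily how the problem was fixed.

```python
from typing import Sequence, Tuple, Union

PaddingArg = Union[str, int, Sequence[int]]


def normalize_padding(padding: PaddingArg, n_dims: int) -> Union[str, Tuple[int, ...]]:
    """Hypothetical helper (not a nobuco API).

    Coerce a PyTorch-style padding argument into either the string 'same'
    or a tuple of `n_dims` integers, so later code never indexes a string.
    """
    if isinstance(padding, str):
        if padding == "valid":
            # 'valid' means no padding at all.
            return (0,) * n_dims
        if padding == "same":
            # Leave 'same' symbolic; the amount depends on kernel and dilation.
            return "same"
        raise ValueError(f"Unsupported padding string: {padding!r}")
    if isinstance(padding, int):
        # A single number applies to every spatial dimension.
        return (padding,) * n_dims
    result = tuple(int(p) for p in padding)
    if len(result) != n_dims:
        raise ValueError(f"Expected {n_dims} padding values, got {len(result)}")
    return result
```

With this assumption, `normalize_padding("valid", 2)` and `normalize_padding(0, 2)` give the same tuple, so the string and integer forms can share one code path.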

The code to reproduce the issue is as follows:

import nobuco
import torch
import torch.nn as nn
import torch.nn.functional as F
from nobuco import ChannelOrder


class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv0_weight = nn.Parameter(torch.randn(128, 3, 3, 3))
        self.conv0_bias = nn.Parameter(torch.randn(128))

        self.conv1 = nn.Conv2d(128, 128, 1, 1, padding=0)

        self.conv2_weight = nn.Parameter(torch.randn(128, 128, 3))
        self.conv2_bias = nn.Parameter(torch.randn(128))

        self.conv3 = nn.Conv1d(128, 128, 1, 1, padding=0)

    def forward(self, x):
        x = F.conv2d(x, self.conv0_weight, self.conv0_bias, padding="same")
        x = F.relu(x)
        x = self.conv1(x)
        x = F.relu(x)

        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = F.conv1d(x, self.conv2_weight, self.conv2_bias, padding="same")
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)

        return x


class ErrorModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv0_weight = nn.Parameter(torch.randn(128, 3, 3, 3))
        self.conv0_bias = nn.Parameter(torch.randn(128))

        self.conv1 = nn.Conv2d(128, 128, 1, 1, "valid")

        self.conv2_weight = nn.Parameter(torch.randn(128, 128, 3))
        self.conv2_bias = nn.Parameter(torch.randn(128))

        self.conv3 = nn.Conv1d(128, 128, 1, 1, "valid")

    def forward(self, x):
        x = F.conv2d(x, self.conv0_weight, self.conv0_bias)
        x = F.relu(x)
        x = self.conv1(x)
        x = F.relu(x)

        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = F.conv1d(x, self.conv2_weight, self.conv2_bias)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)

        return x


def main():
    dummy_image = torch.rand(size=(1, 3, 64, 64))

    model = Model().eval()
    _ = nobuco.pytorch_to_keras(
        model,
        args=[dummy_image],
        kwargs=None,
        inputs_channel_order=ChannelOrder.TENSORFLOW,
        outputs_channel_order=ChannelOrder.TENSORFLOW,
    )

    error_model = ErrorModel().eval()
    _ = nobuco.pytorch_to_keras(
        error_model,
        args=[dummy_image],
        kwargs=None,
        inputs_channel_order=ChannelOrder.TENSORFLOW,
        outputs_channel_order=ChannelOrder.TENSORFLOW,
    )  # This will raise an error


if __name__ == "__main__":
    main()
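
For reference, here is a rough sketch of the output sizes the two models should produce in PyTorch, which a successful conversion would be expected to match. It uses the output-size formula from the PyTorch convolution documentation; `conv_out_size` and `expected_lengths` are hypothetical names used only for this illustration, and 'same' padding is written as `padding=1` on the assumption of a kernel size of 3 with stride 1.

```python
def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension of a convolution."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def expected_lengths(side=64):
    """Return the final sequence length for (Model, ErrorModel)."""
    # Model: 3x3 conv with 'same' padding, 1x1 conv, flatten H*W,
    # conv1d with kernel 3 and 'same' padding, conv1d with kernel 1.
    h = conv_out_size(side, kernel=3, padding=1)
    h = conv_out_size(h, kernel=1)
    same_len = conv_out_size(h * h, kernel=3, padding=1)
    same_len = conv_out_size(same_len, kernel=1)

    # ErrorModel: identical layers, but no padding anywhere ('valid' / default).
    h = conv_out_size(side, kernel=3)
    h = conv_out_size(h, kernel=1)
    valid_len = conv_out_size(h * h, kernel=3)
    valid_len = conv_out_size(valid_len, kernel=1)

    return same_len, valid_len


print(expected_lengths())  # (4096, 3842)
```

Both models keep 128 channels throughout, so for the 64x64 dummy image the PyTorch outputs should have shapes (1, 128, 4096) and (1, 128, 3842) respectively.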

@kokeshing
Contributor Author

I have investigated and submitted #32. Please confirm.

@AlexanderLutsenko
Owner

Awesome, thank you!
