Skip to content

Commit

Permalink
pnnx auto inputshape from traced inputs (#5825)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Dec 18, 2024
1 parent a12baae commit 4b0d2de
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 11 deletions.
195 changes: 192 additions & 3 deletions tools/pnnx/src/load_torchscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <torch/script.h>
#include <torch/csrc/api/include/torch/version.h>
#include <torch/csrc/jit/serialization/import_read.h>
#ifdef PNNX_TORCHVISION
namespace vision {
int64_t cuda_version();
Expand Down Expand Up @@ -421,6 +422,26 @@ static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t)
return torch::kFloat32;
}

static const char* get_at_tensor_type_str(const at::ScalarType& st)
{
if (st == c10::ScalarType::Float) return "f32";
if (st == c10::ScalarType::Double) return "f64";
if (st == c10::ScalarType::Half) return "f16";
if (st == c10::ScalarType::Int) return "i32";
if (st == c10::ScalarType::Long) return "i64";
if (st == c10::ScalarType::Short) return "i16";
if (st == c10::ScalarType::Char) return "i8";
if (st == c10::ScalarType::Byte) return "u8";
if (st == c10::ScalarType::ComplexFloat) return "c64";
if (st == c10::ScalarType::ComplexDouble) return "c128";
if (st == c10::ScalarType::ComplexHalf) return "c32";
if (st == c10::ScalarType::BFloat16) return "bf16";

// unknown
fprintf(stderr, "unsupported tensor elem data type %d\n", (int)st);
return "";
}

const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind)
{
for (const auto& n : graph->nodes())
Expand All @@ -432,6 +453,142 @@ const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Grap
return 0;
}

static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types)
{
for (size_t i = 0; i < shapes.size(); i++)
{
const std::vector<int64_t>& s = shapes[i];
const std::string& t = types[i];
fprintf(stderr, "[");
for (size_t j = 0; j < s.size(); j++)
{
fprintf(stderr, "%ld", s[j]);
if (j != s.size() - 1)
fprintf(stderr, ",");
}
fprintf(stderr, "]");
fprintf(stderr, "%s", t.c_str());
if (i != shapes.size() - 1)
fprintf(stderr, ",");
}
}

static void append_input(std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types, const torch::jit::IValue& v)
{
if (v.isTensor())
{
const auto& tensor = v.toTensor();
input_shapes.push_back(tensor.sizes().vec());
input_types.push_back(get_at_tensor_type_str(tensor.scalar_type()));
}
else if (v.isList())
{
for (const auto& v2 : v.toList())
append_input(input_shapes, input_types, v2);
}
else if (v.isTuple())
{
for (const auto& v2 : v.toTuple()->elements())
append_input(input_shapes, input_types, v2);
}
else if (v.isGenericDict())
{
for (const auto& kv2 : v.toGenericDict())
append_input(input_shapes, input_types, kv2.value());
}
else
{
fprintf(stderr, "unsupported traced input type %s\n", v.tagKind().c_str());
}
}

static void get_traced_input_shape(const std::string& ptpath, std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types)
{
try
{
// read traced_inputs.pkl
caffe2::serialize::PyTorchStreamReader reader(ptpath);
auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", std::nullopt, std::nullopt, std::nullopt, reader);

if (!v.isGenericDict())
return;

for (const auto& entry : v.toGenericDict())
{
if (entry.key() != "forward")
continue;

append_input(input_shapes, input_types, entry.value());
break;
}
}
catch (...)
{
// no traced_inputs.pkl pass
}
}

static bool check_input_shape(const std::vector<std::vector<int64_t> >& traced_input_shapes, const std::vector<std::string>& traced_input_types, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
{
if (input_shapes.size() != traced_input_shapes.size())
{
fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)traced_input_shapes.size(), (int)input_shapes.size());
return false;
}

for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
bool matched = true;

if (input_shapes[i].size() != traced_input_shapes[i].size())
{
matched = false;
}
else
{
for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
{
if (input_shapes[i][j] != traced_input_shapes[i][j])
matched = false;
}
}

if (input_types[i] != traced_input_types[i])
matched = false;

if (!matched)
{
fprintf(stderr, "input_shapes[%d] expect [", (int)i);
for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
{
fprintf(stderr, "%ld", traced_input_shapes[i][j]);
if (j + 1 != traced_input_shapes[i].size())
fprintf(stderr, ",");
}
fprintf(stderr, "]%s but got ", traced_input_types[i].c_str());
if (input_shapes.empty())
{
fprintf(stderr, "nothing\n");
}
else
{
fprintf(stderr, "[");
for (size_t j = 0; j < input_shapes[i].size(); j++)
{
fprintf(stderr, "%ld", input_shapes[i][j]);
if (j + 1 != input_shapes[i].size())
fprintf(stderr, ",");
}
fprintf(stderr, "]%s\n", input_types[i].c_str());
}

return false;
}
}

return true;
}

int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes,
Expand All @@ -443,6 +600,38 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& foldable_constants_zippath,
std::set<std::string>& foldable_constants)
{
// get input shape from traced torchscript
std::vector<std::vector<int64_t> > traced_input_shapes;
std::vector<std::string> traced_input_types;
get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);

if (!traced_input_shapes.empty())
{
fprintf(stderr, "get inputshape from traced inputs\n");
fprintf(stderr, "inputshape = ");
print_shape_list(traced_input_shapes, traced_input_types);
fprintf(stderr, "\n");

if (!input_shapes.empty())
{
// input shape sanity check
if (!check_input_shape(traced_input_shapes, traced_input_types, input_shapes, input_types))
{
return -1;
}
}
// traced torchscript always has static input shapes
// if (!input_shapes2.empty() && !check_input_shape(ptpath, input_shapes2, input_types2))
// {
// return -1;
// }
}
else
{
traced_input_shapes = input_shapes;
traced_input_types = input_types;
}

#ifdef PNNX_TORCHVISION
// call some vision api to register vision ops :P
(void)vision::cuda_version();
Expand All @@ -467,10 +656,10 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
}

std::vector<at::Tensor> input_tensors;
for (size_t i = 0; i < input_shapes.size(); i++)
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes[i];
const std::string& type = input_types[i];
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
Expand Down
5 changes: 4 additions & 1 deletion tools/pnnx/tests/test_convnext_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_convnext_tiny.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_convnext_tiny.pt")
else:
os.system("../src/pnnx test_convnext_tiny.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_convnext_tiny_pnnx
Expand Down
6 changes: 5 additions & 1 deletion tools/pnnx/tests/test_mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torchvision.models as models
from packaging import version

def test():
net = models.mobilenet_v2()
Expand All @@ -30,7 +31,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_mobilenet_v2.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_mobilenet_v2.pt")
else:
os.system("../src/pnnx test_mobilenet_v2.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_mobilenet_v2_pnnx
Expand Down
6 changes: 5 additions & 1 deletion tools/pnnx/tests/test_mobilenet_v3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torchvision.models as models
from packaging import version

def test():
net = models.mobilenet_v3_small()
Expand All @@ -30,7 +31,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_mobilenet_v3_small.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_mobilenet_v3_small.pt")
else:
os.system("../src/pnnx test_mobilenet_v3_small.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_mobilenet_v3_small_pnnx
Expand Down
6 changes: 5 additions & 1 deletion tools/pnnx/tests/test_resnet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torchvision.models as models
from packaging import version

def test():
net = models.resnet18()
Expand All @@ -30,7 +31,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_resnet18.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_resnet18.pt")
else:
os.system("../src/pnnx test_resnet18.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_resnet18_pnnx
Expand Down
6 changes: 5 additions & 1 deletion tools/pnnx/tests/test_shufflenet_v2_x1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torchvision.models as models
from packaging import version

def test():
net = models.shufflenet_v2_x1_0()
Expand All @@ -30,7 +31,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_shufflenet_v2_x1_0.pt")
else:
os.system("../src/pnnx test_shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_shufflenet_v2_x1_0_pnnx
Expand Down
6 changes: 5 additions & 1 deletion tools/pnnx/tests/test_squeezenet1_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
import torchvision.models as models
from packaging import version

def test():
net = models.squeezenet1_1()
Expand All @@ -30,7 +31,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_squeezenet1_1.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_squeezenet1_1.pt")
else:
os.system("../src/pnnx test_squeezenet1_1.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_squeezenet1_1_pnnx
Expand Down
5 changes: 4 additions & 1 deletion tools/pnnx/tests/test_swin_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_swin_t.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_swin_t.pt")
else:
os.system("../src/pnnx test_swin_t.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_swin_t_pnnx
Expand Down
5 changes: 4 additions & 1 deletion tools/pnnx/tests/test_vit_b_32.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def test():

# torchscript to pnnx
import os
os.system("../src/pnnx test_vit_b_32.pt inputshape=[1,3,224,224]")
if version.parse(torch.__version__) >= version.parse('2.0'):
os.system("../src/pnnx test_vit_b_32.pt")
else:
os.system("../src/pnnx test_vit_b_32.pt inputshape=[1,3,224,224]")

# pnnx inference
import test_vit_b_32_pnnx
Expand Down

0 comments on commit 4b0d2de

Please sign in to comment.