Codegen Non-Native IR Nodes (pytorch#76535)
Add codegen infrastructure to generate IR nodes for non-native ops.

The proposed change adds a `non_native` key to the `{backend}_native_functions.yaml` file; the key holds schema definitions similar to those found in `native_functions.yaml`, e.g.
```
non_native:
    ...
    - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
    ...
```
These definitions are parsed into a `LazyIrSchema` that can be used to generate IR nodes via `GenLazyIR`.
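
To make the codegen output concrete, below is a rough, hand-written sketch of the kind of IR node class `GenLazyIR` produces for the `expand` entry above. This is an approximation based on the existing generated `LazyIr.h` nodes, not the literal output of this PR; the class layout, constructor signature, and hashing call are assumptions.
```cpp
// Hand-written approximation of a GenLazyIR-style node for the `expand`
// entry above; names and the exact constructor signature are assumptions.
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

namespace torch {
namespace lazy {

class Expand : public TsNode {
 public:
  Expand(
      const Value& input,
      const std::vector<int64_t>& size,
      bool is_scalar_expand,
      std::vector<Shape>&& shapes)
      // One operand (`input`); the value-type arguments are folded into the node hash.
      : TsNode(
            OpKind(at::aten::expand),
            {input},
            std::move(shapes),
            /*num_outputs=*/1,
            MHash(size, is_scalar_expand)),
        size(size),
        is_scalar_expand(is_scalar_expand) {}

  std::vector<int64_t> size;
  bool is_scalar_expand;
};

} // namespace lazy
} // namespace torch
```
In this scheme, the schema's `Tensor` arguments presumably become node operands while value-type arguments (here `size` and `is_scalar_expand`) are stored on the node and contribute to its hash; entries with an `opkind` override (e.g. `ltc_cast`) would substitute that symbol for the default `at::aten::*` one.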

Fixes pytorch#74628

CC: @wconstab @desertfire @henrytwo

Pull Request resolved: pytorch#76535
Approved by: https://github.com/wconstab
antoniojkim authored and pytorchmergebot committed May 24, 2022
1 parent 13dcba8 commit 02c4d87
Showing 55 changed files with 497 additions and 1,348 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -1865,6 +1865,7 @@ test_suite(
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
38 changes: 38 additions & 0 deletions aten/src/ATen/native/ts_native_functions.yaml
@@ -178,3 +178,41 @@ supported:
- _unsafe_view
autograd:
- max_pool3d

# Ops that don't have native schema definitions and are dispatched within Lazy Tensor Core
non_native:
- func: scalar(Scalar value, ScalarType type) -> Tensor
opkind: at::prim::Constant
properties:
- ShapeCompute
- TreatScalarsAsConstants
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, int[] output_size) -> Tensor
properties:
- ShapeCompute
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
opkind: ltc_cast
properties:
- ShapeCompute

# View ops only required until a proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
opkind: ltc_diagonal_view_update
properties:
- ShapeCompute
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
- func: permute(Tensor input, int[] dims) -> Tensor
- func: resize(Tensor input, int[] size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
opkind: ltc_select_view_update
properties:
- ShapeCompute
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor
11 changes: 11 additions & 0 deletions aten/src/ATen/templates/LazyNonNativeIr.h
@@ -0,0 +1,11 @@
#pragma once

${lazy_non_native_ir_inc}

// This file contains autogenerated LazyTensor Non Native IR nodes

${namespace_prologue}

${non_native_ir_nodes}

${namespace_epilogue}
2 changes: 2 additions & 0 deletions build.bzl
@@ -28,6 +28,7 @@ def define_targets(rules):
":DispatchKeyNativeFunctions.cpp",
":DispatchKeyNativeFunctions.h",
":LazyIr.h",
":LazyNonNativeIr.h",
":RegisterDispatchKey.cpp",
":native_functions.yaml",
":shape_inference.h",
@@ -88,6 +89,7 @@ GENERATED_TESTING_PY = [

GENERATED_LAZY_H = [
"torch/csrc/lazy/generated/LazyIr.h",
"torch/csrc/lazy/generated/LazyNonNativeIr.h",
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
]

2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
@@ -380,6 +380,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
list(APPEND GENERATED_H_TORCH
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNonNativeIr.h"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.h"
)
endif()
@@ -444,6 +445,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"
22 changes: 3 additions & 19 deletions tools/build_variables.bzl
@@ -417,35 +417,19 @@ lazy_tensor_core_sources = [
# We can't build all of the ts backend under certain build configurations, e.g. mobile,
# since it depends on things like autograd, meta functions, which may be disabled
lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/cast.cpp",
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/narrow.cpp",
"torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/permute.cpp",
"torch/csrc/lazy/ts_backend/view_ops/resize.cpp",
"torch/csrc/lazy/ts_backend/view_ops/select.cpp",
"torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp",
"torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp",
"torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/view.cpp",
"torch/csrc/lazy/ts_backend/ts_node.cpp",
"torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
"torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
"torch/csrc/lazy/ts_backend/ts_backend_impl.cpp",
"torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp",
"torch/csrc/lazy/ts_backend/ts_lowering_context.cpp",
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
"torch/csrc/lazy/ts_backend/ts_node.cpp",
"torch/csrc/lazy/ts_backend/ts_node_lowering.cpp",
]

2 changes: 1 addition & 1 deletion tools/test/test_gen_backend_stubs.py
@@ -237,7 +237,7 @@ def test_unrecognized_key(self) -> None:
output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
self.assertExpectedInline(
output_error,
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""", # noqa: B950
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""", # noqa: B950
)

# if use_out_as_primary is provided, it must be a bool
6 changes: 5 additions & 1 deletion torch/csrc/lazy/core/ir.cpp
@@ -19,6 +19,10 @@ hash_t Output::hash() const {
return HashCombine(node->hash(), Hash(index));
}

hash_t Output::shapeHash() const {
return HashCombine(node->shapeHash(), Hash(index));
}

std::string Output::ToString() const {
std::stringstream ss;
ss << node->ToString() << ", index=" << index;
@@ -144,7 +148,7 @@ std::string Node::ToString() const {

void Node::AddOperand(NodePtr node, size_t index) {
CHECK_LT(index, node->num_outputs());
operands_.push_back(std::move(node));
operands_.push_back(node);
operands_as_outputs_.emplace_back(operands_.back().get(), index);
}

1 change: 1 addition & 0 deletions torch/csrc/lazy/core/ir.h
@@ -214,6 +214,7 @@ struct TORCH_API Output {
: node(node), index(index) {}

hash_t hash() const;
hash_t shapeHash() const;

bool operator==(const Output& rhs) const {
return node == rhs.node && index == rhs.index;
16 changes: 8 additions & 8 deletions torch/csrc/lazy/core/ops/utils.h
@@ -6,33 +6,33 @@
namespace torch {
namespace lazy {

bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
TORCH_API bool StrideIsSupported(c10::ArrayRef<int64_t> stride);

std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
TORCH_API std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);

Shape MakeDiagonalShape(
TORCH_API Shape MakeDiagonalShape(
const Shape& shape,
int64_t offset,
int64_t dim1,
int64_t dim2);

Shape MakePermuteShape(
TORCH_API Shape MakePermuteShape(
const Shape& source_shape,
c10::ArrayRef<int64_t> permutation);

Shape MakeSelectShape(
TORCH_API Shape MakeSelectShape(
const Shape& shape,
int64_t dim,
int64_t start,
int64_t end,
int64_t stride);

int64_t GetStride(int64_t start, int64_t end, int64_t stride);
TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride);

std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
TORCH_API std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
int64_t squeeze_dim);

std::vector<int64_t> BuildUnsqueezedDimensions(
TORCH_API std::vector<int64_t> BuildUnsqueezedDimensions(
c10::ArrayRef<int64_t> dimensions,
int64_t squeeze_dim);

7 changes: 5 additions & 2 deletions torch/csrc/lazy/core/permutation_util.h
@@ -23,8 +23,11 @@ std::vector<typename Container::value_type> PermuteDimensions(
const Container& dimensions) {
using T = typename Container::value_type;
TORCH_CHECK(
dimensions.size() == permutation.size() && IsPermutation(permutation),
"Invalid permutation specified");
dimensions.size() == permutation.size(),
"Invalid permutation specified. dimensions.size() != permutation.size() (", dimensions.size(), " vs. ", permutation.size(), ")");
TORCH_CHECK(
IsPermutation(permutation),
"Invalid permutation specified. Permutation is not permutation");
std::vector<T> output(dimensions.size());
for (const auto i : c10::irange(permutation.size())) {
output[i] = dimensions[permutation[i]];
63 changes: 63 additions & 0 deletions torch/csrc/lazy/core/shape_inference.cpp
@@ -45,10 +45,12 @@

#include <torch/csrc/lazy/core/shape_inference.h>

#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/shape.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/InferSize.h>
#include <ATen/WrapDimUtils.h>
#include <aten/src/ATen/native/ReduceOpsUtils.h>
#include <c10/core/ScalarType.h>
@@ -629,6 +631,67 @@ std::vector<Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t di
return {Shape(self.scalar_type(), self.sizes().vec())};
}


// Non-Native Ops
std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type) {
return { Shape(type, {}) };
}
std::vector<Shape> compute_shape_expand(const Output& input, const std::vector<int64_t>& size, const bool& is_scalar_expand) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_view(const Output& input, const std::vector<int64_t>& output_sizes) {
const Shape& input_shape = input.shape();
const auto complete_output_sizes =
at::infer_size(output_sizes, input_shape.numel());
return { Shape(input_shape.scalar_type(), complete_output_sizes) };
}
std::vector<Shape> compute_shape_cast(const Output& input, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype) {
Shape shape = input.shape();
shape.set_scalar_type(dtype);
return { shape };
}


// View Ops
std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
return { Shape(target.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
return { target.shape() };
}
std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2) {
return { MakeDiagonalShape(input.shape(), offset, dim1, dim2) };
}
std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices) {
return { input.shape() };
}
std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes) {
return { Shape(input.shape().scalar_type(), sizes) };
}
std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims) {
return { MakePermuteShape(input.shape(), dims) };
}
std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size) {
return { Shape(input.shape().scalar_type(), size) };
}
std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
return { target.shape() };
}
std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride) {
return { MakeSelectShape(input.shape(), dim, start, end, stride) };
}
std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
const auto& input_shape = input.shape();
return { torch::lazy::Shape(input_shape.scalar_type(), BuildSqueezedDimensions(input_shape.sizes(), dim)) };
}
std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim) {
const auto& input_shape = input.shape();
return { torch::lazy::Shape(input_shape.scalar_type(), BuildUnsqueezedDimensions(input_shape.sizes(), dim)) };
}

// Restore unused-parameters warnings
#pragma GCC diagnostic pop

22 changes: 22 additions & 0 deletions torch/csrc/lazy/core/shape_inference.h
@@ -4,6 +4,7 @@
#include <c10/core/ScalarType.h>
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>
#include <vector>
@@ -68,5 +69,26 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape__to_copy(const at::Tenso
TORCH_API std::vector<torch::lazy::Shape> compute_shape_trace(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_zero_functional(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy(const at::Tensor & self, int64_t dim, int64_t start, c10::SymInt length);

// Non-Native ops
TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);
TORCH_API std::vector<Shape> compute_shape_expand(const Output& input0, const std::vector<int64_t>& size, const bool& is_scalar_expand);
TORCH_API std::vector<Shape> compute_shape_view(const Output& input0, const std::vector<int64_t>& output_sizes);
TORCH_API std::vector<Shape> compute_shape_cast(const Output& input0, const at::ScalarType& dtype, const c10::optional<at::ScalarType>& stype);

// View Ops
TORCH_API std::vector<Shape> compute_shape_as_strided_view_update(const Output& target, const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_as_strided(const Output& input, const std::vector<int64_t>& size, const std::vector<int64_t>& stride, const int64_t& storage_offset);
TORCH_API std::vector<Shape> compute_shape_diagonal_view_update(const Output& target, const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_diagonal(const Output& input, const int64_t& offset, const int64_t& dim1, const int64_t& dim2);
TORCH_API std::vector<Shape> compute_shape_narrow_view_update(const Output& input, const Output& source, const std::vector<int64_t>& base_indices);
TORCH_API std::vector<Shape> compute_shape_narrow(const Output& input, const std::vector<int64_t>& base_indices, const std::vector<int64_t>& sizes);
TORCH_API std::vector<Shape> compute_shape_permute(const Output& input, const std::vector<int64_t>& dims);
TORCH_API std::vector<Shape> compute_shape_resize(const Output& input, const std::vector<int64_t>& size);
TORCH_API std::vector<Shape> compute_shape_select_view_update(const Output& target, const Output& source, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_select(const Output& input, const int64_t& dim, const int64_t& start, const int64_t& end, const int64_t& stride);
TORCH_API std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim);
TORCH_API std::vector<Shape> compute_shape_unsqueeze(const Output& input, const int& dim);

} // namespace lazy
} // namespace torch
24 changes: 2 additions & 22 deletions torch/csrc/lazy/ts_backend/ir_builder.h
@@ -4,31 +4,11 @@
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/permute.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/resize.h>
#include <torch/csrc/lazy/ts_backend/view_ops/squeeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/diagonal.h>
#include <torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.h>
#include <torch/csrc/lazy/ts_backend/view_ops/as_strided.h>
#include <torch/csrc/lazy/ts_backend/view_ops/unsqueeze.h>
#include <torch/csrc/lazy/ts_backend/view_ops/select.h>
#include <torch/csrc/lazy/ts_backend/view_ops/view.h>
#include <torch/csrc/lazy/ts_backend/ops/cast.h>
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
#include <torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/ops/scalar.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/expand.h>

// This file contains the TorchScript IrBuilder
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>

namespace torch {
namespace lazy {