Skip to content

Commit

Permalink
Add PINN example
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Jan 13, 2023
1 parent 08c8279 commit 9a1f965
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
2 changes: 1 addition & 1 deletion benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
57 changes: 41 additions & 16 deletions benchmark/pinn.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,51 @@
using TaylorDiff, Zygote
using Flux
using ChainRulesCore: @opt_out
using TaylorDiff
using Zygote
using Plots

const input = 2
const hidden = 16

struct PINN
W₁
b₁
W₂
b₂
model = Chain(
Dense(input => hidden, sin),
Dense(hidden => hidden, sin),
Dense(hidden => 1),
first
)
trial(model, x) = model(x)

ε = cbrt(eps(Float32))
ε₁ = [ε, 0]
ε₂ = [0, ε]

M = 100
data = [rand(input) for _ in 1:M]
function loss_by_finitediff(model, x)
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
trial(model, x - ε₂) - 4 * trial(model, x)) /
ε^2 + sin* x[1]) * sin* x[2])
abs2(error)
end
function loss_by_taylordiff(model, x)
f(x) = trial(model, x)
error = derivative(f, x, [1., 0.], 2) + derivative(f, x, [0., 1.], 2) + sin* x[1]) * sin* x[2])
abs2(error)
end

(pinn::PINN)(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * first(pinn.W₂ * exp.(pinn.W₁ * x + pinn.b₁) + pinn.b₂)
opt = Flux.setup(Adam(), model)

dataset = [rand(input) for i in 1:10]
function loss(pinn)
out = 0.0
for x in dataset
out += derivative(pinn, x, [1., 0.], Val(2))
end
out
allloss(model, loss) = sum([loss(model, x) for x in data])
for epoch in 1:1000
Flux.train!(loss_by_taylordiff, model, data, opt)
end

myPINN = PINN(rand(hidden, input), rand(hidden), rand(1, hidden), rand(1))
grid = 0:0.01:1
solution(x, y) = (sin* x) * sin* y)) / (2π^2)
u = [trial(model, [x, y]) for x in grid, y in grid]
utrue = [solution(x, y) for x in grid, y in grid]
diff_u = abs.(u .- utrue)

gradient(loss, myPINN)
surface(u)
surface(utrue)
surface(diff_u)
12 changes: 9 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import ChainRulesCore: rrule, RuleConfig
import ChainRulesCore: rrule, RuleConfig, ProjectTo
using ZygoteRules: @adjoint

contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} = mapreduce(*, +, value(a), value(b))

NONLINEAR_UNARY_FUNCTIONS = Function[
exp, exp2, exp10, expm1,
log, log2, log10, log1p,
inv, sqrt, cbrt,
sin, cos, tan, cot, sec, csc,
asin, acos, atan, acot, asec, acsc,
sinh, cosh, tanh, coth, sech, csch,
Expand Down Expand Up @@ -52,8 +53,9 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
return extract_derivative(t, i), extract_derivative_pullback
end

function rrule(::typeof(*), A::Matrix{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
gemv_pullback(x̄) = NoTangent(), contract.(x̄, transpose(t)), transpose(A) *
function rrule(::typeof(*), A::Matrix{S}, t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T}
project_A = ProjectTo(A)
gemv_pullback(x̄) = NoTangent(), project_A(contract.(x̄, transpose(t))), transpose(A) *
return A * t, gemv_pullback
end

Expand All @@ -70,3 +72,7 @@ end
@adjoint +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} = t + v, x̄ -> (x̄, map(primal, x̄))

@adjoint +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} = v + t, x̄ -> (map(primal, x̄), x̄)

(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)

(project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real} = project(primal(dx))
1 change: 1 addition & 0 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Base: hypot, max, min
@inline sqrt(t::TaylorScalar) = t^0.5
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
@inline inv(t::TaylorScalar) = 1 / t
@inline abs(t::TaylorScalar) = primal(t) >= 0 ? t : -t

for func in (:exp, :expm1, :exp2, :exp10)
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
Expand Down

0 comments on commit 9a1f965

Please sign in to comment.