Skip to content

Commit

Permalink
Merge pull request JuliaLang#47667 from JuliaLang/jn/47476
Browse files Browse the repository at this point in the history
ensure proper handling of sparams for widened compile signatures

Fix JuliaLang#47476
  • Loading branch information
vtjnash authored Dec 12, 2022
2 parents 4ad6aef + 9e5e28f commit 770f5a3
Show file tree
Hide file tree
Showing 13 changed files with 301 additions and 124 deletions.
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
return ci
end
if may_discard_trees(interp)
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, def))
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, linfo.sparam_vals, def))
else
cache_the_tree = true
end
Expand Down
11 changes: 8 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ function get_compileable_sig(method::Method, @nospecialize(atype), sparams::Simp
mt, atype, sparams, method)
end

isa_compileable_sig(@nospecialize(atype), method::Method) =
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any), atype, method))
isa_compileable_sig(@nospecialize(atype), sparams::SimpleVector, method::Method) =
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any, Any), atype, sparams, method))

# eliminate UnionAll vars that might be degenerate due to having identical bounds,
# or a concrete upper bound and appearing covariantly.
Expand Down Expand Up @@ -206,7 +206,12 @@ function specialize_method(method::Method, @nospecialize(atype), sparams::Simple
if compilesig
new_atype = get_compileable_sig(method, atype, sparams)
new_atype === nothing && return nothing
atype = new_atype
if atype !== new_atype
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), new_atype, method.sig)::SimpleVector
if sparams === sp_[2]::SimpleVector
atype = new_atype
end
end
end
if preexisting
# check cached specializations
Expand Down
338 changes: 230 additions & 108 deletions src/gf.c

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ static jl_callptr_t _jl_compile_codeinst(
// hack to export this pointer value to jl_dump_method_disasm
jl_atomic_store_release(&this_code->specptr.fptr, (void*)getAddressForFunction(decls.specFunctionObject));
}
if (this_code== codeinst)
if (this_code == codeinst)
fptr = addr;
}

Expand Down
26 changes: 25 additions & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1093,12 +1093,36 @@ jl_value_t *jl_unwrap_unionall(jl_value_t *v)
}

// wrap `t` in the same unionalls that surround `u`
// where `t` is derived from `u`, so the error checks in jl_type_unionall are unnecessary
jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u)
{
if (!jl_is_unionall(u))
return t;
JL_GC_PUSH1(&t);
t = jl_rewrap_unionall(t, ((jl_unionall_t*)u)->body);
jl_tvar_t *v = ((jl_unionall_t*)u)->var;
// normalize `T where T<:S` => S
if (t == (jl_value_t*)v)
return v->ub;
// where var doesn't occur in body just return body
if (!jl_has_typevar(t, v))
return t;
JL_GC_PUSH1(&t);
//if (v->lb == v->ub) // TODO maybe
// t = jl_substitute_var(body, v, v->ub);
//else
t = jl_new_struct(jl_unionall_type, v, t);
JL_GC_POP();
return t;
}

// wrap `t` in the same unionalls that surround `u`
// where `t` is extended from `u`, so the checks in jl_rewrap_unionall are unnecessary
jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u)
{
if (!jl_is_unionall(u))
return t;
t = jl_rewrap_unionall_(t, ((jl_unionall_t*)u)->body);
JL_GC_PUSH1(&t);
t = jl_new_struct(jl_unionall_type, ((jl_unionall_t*)u)->var, t);
JL_GC_POP();
return t;
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ STATIC_INLINE int jl_is_concrete_type(jl_value_t *v) JL_NOTSAFEPOINT
return jl_is_datatype(v) && ((jl_datatype_t*)v)->isconcretetype;
}

JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_method_t *definition);
JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_svec_t *sparams, jl_method_t *definition);

// type constructors
JL_DLLEXPORT jl_typename_t *jl_new_typename_in(jl_sym_t *name, jl_module_t *inmodule, int abstract, int mutabl);
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ JL_DLLEXPORT jl_value_t *jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_
jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u);
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u);
int jl_count_union_components(jl_value_t *v);
JL_DLLEXPORT jl_value_t *jl_nth_union_component(jl_value_t *v JL_PROPAGATES_ROOT, int i) JL_NOTSAFEPOINT;
int jl_find_union_component(jl_value_t *haystack, jl_value_t *needle, unsigned *nth) JL_NOTSAFEPOINT;
Expand Down
4 changes: 2 additions & 2 deletions src/precompile.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ static void jl_compile_all_defs(jl_array_t *mis)
size_t i, l = jl_array_len(allmeths);
for (i = 0; i < l; i++) {
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
if (jl_isa_compileable_sig((jl_tupletype_t*)m->sig, m)) {
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
// method has a single compilable specialization, e.g. its definition
// signature is concrete. in this case we can just hint it.
jl_compile_hint((jl_tupletype_t*)m->sig);
Expand Down Expand Up @@ -354,7 +354,7 @@ static void *jl_precompile_(jl_array_t *m)
mi = (jl_method_instance_t*)item;
size_t min_world = 0;
size_t max_world = ~(size_t)0;
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->def.method))
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->sparam_vals, mi->def.method))
mi = jl_get_specialization1((jl_tupletype_t*)mi->specTypes, jl_atomic_load_acquire(&jl_world_counter), &min_world, &max_world, 0);
if (mi)
jl_array_ptr_1d_push(m2, (jl_value_t*)mi);
Expand Down
6 changes: 3 additions & 3 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -2890,8 +2890,8 @@ static jl_value_t *intersect_sub_datatype(jl_datatype_t *xd, jl_datatype_t *yd,
jl_value_t *super_pattern=NULL;
JL_GC_PUSH2(&isuper, &super_pattern);
jl_value_t *wrapper = xd->name->wrapper;
super_pattern = jl_rewrap_unionall((jl_value_t*)((jl_datatype_t*)jl_unwrap_unionall(wrapper))->super,
wrapper);
super_pattern = jl_rewrap_unionall_((jl_value_t*)((jl_datatype_t*)jl_unwrap_unionall(wrapper))->super,
wrapper);
int envsz = jl_subtype_env_size(super_pattern);
jl_value_t *ii = jl_bottom_type;
{
Expand Down Expand Up @@ -3528,7 +3528,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t *
if (jl_is_uniontype(ans_unwrapped)) {
ans_unwrapped = switch_union_tuple(((jl_uniontype_t*)ans_unwrapped)->a, ((jl_uniontype_t*)ans_unwrapped)->b);
if (ans_unwrapped != NULL) {
*ans = jl_rewrap_unionall(ans_unwrapped, *ans);
*ans = jl_rewrap_unionall_(ans_unwrapped, *ans);
}
}
JL_GC_POP();
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
# this is needed to disambiguate
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X

rand(X) = rand(default_rng(), X)
rand(::Type{X}) where {X} = rand(default_rng(), X)
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ f11366(x::Type{Ref{T}}) where {T} = Ref{x}


let f(T) = Type{T}
@test Base.return_types(f, Tuple{Type{Int}}) == [Type{Type{Int}}]
@test Base.return_types(f, Tuple{Type{Int}}) == Any[Type{Type{Int}}]
end

# issue #9222
Expand Down
25 changes: 25 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7902,3 +7902,28 @@ struct ModTparamTestStruct{M}; end
end
@test ModTparamTestStruct{@__MODULE__}() == 2
@test ModTparamTestStruct{ModTparamTest}() == 1

# issue #47476
f47476(::Union{Int, NTuple{N,Int}}...) where {N} = N
# force it to populate the MethodInstance specializations cache
# with the correct sparams
code_typed(f47476, (Vararg{Union{Int, NTuple{2,Int}}},));
code_typed(f47476, (Int, Vararg{Union{Int, NTuple{2,Int}}},));
code_typed(f47476, (Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
code_typed(f47476, (Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
code_typed(f47476, (Int, Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
@test f47476(1, 2, 3, 4, 5, 6, (7, 8)) === 2
@test_throws UndefVarError(:N) f47476(1, 2, 3, 4, 5, 6, 7)

vect47476(::Type{T}) where {T} = T
@test vect47476(Type{Type{Type{Int32}}}) === Type{Type{Type{Int32}}}
@test vect47476(Type{Type{Type{Int64}}}) === Type{Type{Type{Int64}}}

g47476(::Union{Nothing,Int,Val{T}}...) where {T} = T
@test_throws UndefVarError(:T) g47476(nothing, 1, nothing, 2, nothing, 3, nothing, 4, nothing, 5)
@test g47476(nothing, 1, nothing, 2, nothing, 3, nothing, 4, nothing, 5, Val(6)) === 6
let spec = only(methods(g47476)).specializations
@test !isempty(spec)
@test any(mi -> mi !== nothing && Base.isvatuple(mi.specTypes), spec)
@test all(mi -> mi === nothing || !Base.has_free_typevars(mi.specTypes), spec)
end
4 changes: 2 additions & 2 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,8 @@ end
f(x, y) = x + y
f(x::Int, y) = 2x + y
end
precompile(M.f, (Int, Any))
precompile(M.f, (AbstractFloat, Any))
@test precompile(M.f, (Int, Any))
@test precompile(M.f, (AbstractFloat, Any))
mis = map(methods(M.f)) do m
m.specializations[1]
end
Expand Down

0 comments on commit 770f5a3

Please sign in to comment.