Skip to content

Commit

Permalink
Support mixed constant composites.
Browse files Browse the repository at this point in the history
  • Loading branch information
HansKristian-Work committed Sep 27, 2017
1 parent 5e1d6fb commit ceefae5
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 62 deletions.
58 changes: 37 additions & 21 deletions spirv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,16 +712,27 @@ struct SPIRConstant : IVariant
struct ConstantVector
{
Constant r[4];
uint32_t id[4] = {}; // If != 0, this is a specialization constant, and we should keep track of it as such.
uint32_t id[4] = {}; // If != 0, this element is a specialization constant, and we should keep track of it as such.
uint32_t vecsize = 1;
};

struct ConstantMatrix
{
ConstantVector c[4];
uint32_t id[4] = {}; // If != 0, this column is a specialization constant, and we should keep track of it as such.
uint32_t columns = 1;
};

inline uint32_t specialization_constant_id(uint32_t col, uint32_t row) const
{
return m.c[col].id[row];
}

inline uint32_t specialization_constant_id(uint32_t col) const
{
return m.id[col];
}

inline uint32_t scalar(uint32_t col = 0, uint32_t row = 0) const
{
return m.c[col].r[row].u32;
Expand Down Expand Up @@ -756,10 +767,12 @@ struct SPIRConstant : IVariant
{
return m.c[0];
}

inline uint32_t vector_size() const
{
return m.c[0].vecsize;
}

inline uint32_t columns() const
{
return m.columns;
Expand Down Expand Up @@ -792,9 +805,6 @@ struct SPIRConstant : IVariant
m.c[0].r[0].u32 = v0;
m.c[0].vecsize = 1;
m.columns = 1;

if (specialized)
m.c[0].id[0] = self;
}

// Construct scalar (64-bit).
Expand All @@ -804,31 +814,37 @@ struct SPIRConstant : IVariant
m.c[0].r[0].u64 = v0;
m.c[0].vecsize = 1;
m.columns = 1;

if (specialized)
m.c[0].id[0] = self;
}

// Construct vector.
// Construct vectors and matrices.
SPIRConstant(uint32_t constant_type_, const SPIRConstant * const *vector_elements, uint32_t num_elements, bool specialized)
: constant_type(constant_type_), specialization(specialized)
{
m.c[0].vecsize = num_elements;
m.columns = 1;
bool matrix = vector_elements[0]->m.c[0].vecsize > 1;

for (uint32_t i = 0; i < num_elements; i++)
if (matrix)
{
m.c[0].r[i] = vector_elements[i]->m.c[0].r[0];
m.c[0].id[i] = vector_elements[i]->m.c[0].id[0];
m.columns = num_elements;

for (uint32_t i = 0; i < num_elements; i++)
{
m.c[i] = vector_elements[i]->m.c[0];
if (vector_elements[i]->specialization)
m.id[i] = vector_elements[i]->self;
}
}
else
{
m.c[0].vecsize = num_elements;
m.columns = 1;

for (uint32_t i = 0; i < num_elements; i++)
{
m.c[0].r[i] = vector_elements[i]->m.c[0].r[0];
if (vector_elements[i]->specialization)
m.c[0].id[i] = vector_elements[i]->self;
}
}
}

// Construct matrix.
SPIRConstant(uint32_t constant_type_, const ConstantVector *vectors, uint32_t num_vectors, bool specialized)
: constant_type(constant_type_), specialization(specialized)
{
m.columns = num_vectors;
memcpy(m.c, vectors, num_vectors * sizeof(*vectors));
}

uint32_t constant_type;
Expand Down
26 changes: 5 additions & 21 deletions spirv_cross.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1686,34 +1686,18 @@ void Compiler::parse(const Instruction &instruction)
if (ctype.basetype == SPIRType::Struct || !ctype.array.empty())
{
set<SPIRConstant>(id, type, ops + 2, length - 2, op == OpSpecConstantComposite);
break;
}

bool matrix = ctype.columns > 1;

if (matrix)
{
uint32_t columns = length - 2;
if (columns > 4)
SPIRV_CROSS_THROW("OpConstantComposite only supports 1, 2, 3 and 4 columns.");

SPIRConstant::ConstantVector c[4];
for (uint32_t i = 0; i < columns; i++)
c[i] = get<SPIRConstant>(ops[2 + i]).vector();
set<SPIRConstant>(id, type, c, columns, op == OpSpecConstantComposite);
}
else
{
uint32_t components = length - 2;
if (components > 4)
SPIRV_CROSS_THROW("OpConstantComposite only supports 1, 2, 3 and 4 components.");
uint32_t elements = length - 2;
if (elements > 4)
SPIRV_CROSS_THROW("OpConstantComposite only supports 1, 2, 3 and 4 elements.");

const SPIRConstant *c[4];
for (uint32_t i = 0; i < components; i++)
for (uint32_t i = 0; i < elements; i++)
c[i] = &get<SPIRConstant>(ops[2 + i]);
set<SPIRConstant>(id, type, c, components, op == OpSpecConstantComposite);
set<SPIRConstant>(id, type, c, elements, op == OpSpecConstantComposite);
}

break;
}

Expand Down
101 changes: 81 additions & 20 deletions spirv_glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,8 +1342,16 @@ void CompilerGLSL::emit_specialization_constant(const SPIRConstant &constant)
auto &type = get<SPIRType>(constant.constant_type);
auto name = to_name(constant.self);

statement("layout(constant_id = ", get_decoration(constant.self, DecorationSpecId), ") const ",
variable_decl(type, name), " = ", constant_expression(constant), ";");
// Only scalars have constant IDs.
if (has_decoration(constant.self, DecorationSpecId))
{
statement("layout(constant_id = ", get_decoration(constant.self, DecorationSpecId), ") const ",
variable_decl(type, name), " = ", constant_expression(constant), ";");
}
else
{
statement("const ", variable_decl(type, name), " = ", constant_expression(constant), ";");
}
}

void CompilerGLSL::replace_illegal_names()
Expand Down Expand Up @@ -2169,7 +2177,11 @@ string CompilerGLSL::constant_expression(const SPIRConstant &c)
string res = type_to_glsl(get<SPIRType>(c.constant_type)) + "(";
for (uint32_t col = 0; col < c.columns(); col++)
{
res += constant_expression_vector(c, col);
if (options.vulkan_semantics && c.specialization_constant_id(col) != 0)
res += to_name(c.specialization_constant_id(col));
else
res += constant_expression_vector(c, col);

if (col + 1 < c.columns())
res += ", ";
}
Expand All @@ -2188,6 +2200,20 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
res += type_to_glsl(type) + "(";

bool splat = backend.use_constructor_splatting && c.vector_size() > 1;

if (splat)
{
// Cannot use constant splatting if we have specialization constants somewhere in the vector.
for (uint32_t i = 0; i < c.vector_size(); i++)
{
if (options.vulkan_semantics && c.specialization_constant_id(vector, i) != 0)
{
splat = false;
break;
}
}
}

if (splat)
{
if (type_to_std430_base_size(type) == 8)
Expand Down Expand Up @@ -2219,7 +2245,11 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar_f32(vector, i));
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
res += convert_to_string(c.scalar_f32(vector, i));

if (backend.float_literal_suffix)
res += "f";
if (i + 1 < c.vector_size())
Expand All @@ -2239,9 +2269,15 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar_f64(vector, i));
if (backend.double_literal_suffix)
res += "lf";
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
{
res += convert_to_string(c.scalar_f64(vector, i));
if (backend.double_literal_suffix)
res += "lf";
}

if (i + 1 < c.vector_size())
res += ", ";
}
Expand All @@ -2261,11 +2297,17 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar_i64(vector, i));
if (backend.long_long_literal_suffix)
res += "ll";
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
res += "l";
{
res += convert_to_string(c.scalar_i64(vector, i));
if (backend.long_long_literal_suffix)
res += "ll";
else
res += "l";
}

if (i + 1 < c.vector_size())
res += ", ";
}
Expand All @@ -2285,11 +2327,17 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar_u64(vector, i));
if (backend.long_long_literal_suffix)
res += "ull";
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
res += "ul";
{
res += convert_to_string(c.scalar_u64(vector, i));
if (backend.long_long_literal_suffix)
res += "ull";
else
res += "ul";
}

if (i + 1 < c.vector_size())
res += ", ";
}
Expand All @@ -2307,9 +2355,15 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar(vector, i));
if (backend.uint32_t_literal_suffix)
res += "u";
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
{
res += convert_to_string(c.scalar(vector, i));
if (backend.uint32_t_literal_suffix)
res += "u";
}

if (i + 1 < c.vector_size())
res += ", ";
}
Expand All @@ -2323,7 +2377,10 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += convert_to_string(c.scalar_i32(vector, i));
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
res += convert_to_string(c.scalar_i32(vector, i));
if (i + 1 < c.vector_size())
res += ", ";
}
Expand All @@ -2337,7 +2394,11 @@ string CompilerGLSL::constant_expression_vector(const SPIRConstant &c, uint32_t
{
for (uint32_t i = 0; i < c.vector_size(); i++)
{
res += c.scalar(vector, i) ? "true" : "false";
if (c.vector_size() > 1 && c.specialization_constant_id(vector, i) != 0)
res += to_name(c.specialization_constant_id(vector, i));
else
res += c.scalar(vector, i) ? "true" : "false";

if (i + 1 < c.vector_size())
res += ", ";
}
Expand Down

0 comments on commit ceefae5

Please sign in to comment.