Skip to content

Commit

Permalink
strip out manual heterogeneity-testing boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
neworderofjamie committed Dec 3, 2024
1 parent 2dfceab commit 0182010
Show file tree
Hide file tree
Showing 12 changed files with 46 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ class GENN_EXPORT CustomConnectivityUpdateGroupMerged : public GroupMerged<Custo
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &name) const;
bool isDerivedParamHeterogeneous(const std::string &name) const;

template<typename A>
void addPrivateVarRefAccess(EnvironmentGroupMergedField<CustomConnectivityUpdateGroupMerged> &env, unsigned int batchSize,
std::function<std::string(VarAccessMode, const typename A::RefType&)> getIndexFn)
Expand Down Expand Up @@ -143,9 +140,6 @@ class GENN_EXPORT CustomConnectivityHostUpdateGroupMerged : public GroupMerged<C
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &name) const;
bool isDerivedParamHeterogeneous(const std::string &name) const;

template<typename A>
void addVars(EnvironmentGroupMergedField<CustomConnectivityHostUpdateGroupMerged> &env, const std::string &count, const BackendBase &backend)
{
Expand Down
10 changes: 0 additions & 10 deletions include/genn/genn/code_generator/customUpdateGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,6 @@ class GENN_EXPORT CustomUpdateGroupMerged : public GroupMerged<CustomUpdateInter
// Static constants
//----------------------------------------------------------------------------
static const std::string name;

private:
//----------------------------------------------------------------------------
// Private methods
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &paramName) const;
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

// ----------------------------------------------------------------------------
Expand All @@ -57,9 +50,6 @@ class GENN_EXPORT CustomUpdateWUGroupMergedBase : public GroupMerged<CustomUpdat
//----------------------------------------------------------------------------
// Public API
//----------------------------------------------------------------------------
bool isParamHeterogeneous(const std::string &paramName) const;
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

boost::uuids::detail::sha1::digest_type getHashDigest() const;

void generateCustomUpdate(EnvironmentExternalBase &env, unsigned int batchSize,
Expand Down
72 changes: 21 additions & 51 deletions include/genn/genn/code_generator/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
Runtime::MergedDynamicFieldDestinations&>>;
using GetFieldNonNumericValueFunc = std::function<NonNumericFieldValue(Runtime::Runtime&, const GroupInternal&, size_t)>;
using GetFieldNumericValueFunc = std::function<Type::NumericValue(const GroupInternal&, size_t)>;
using IsHeterogeneousFn = bool (G::*)(const std::string&) const;
using IsDynamicFn = bool (GroupInternal::*)(const std::string&) const;
using IsVarInitHeterogeneousFn = bool (G::*)(const std::string&, const std::string&) const;
using GetParamValuesFn = const std::map<std::string, Type::NumericValue> &(GroupInternal::*)(void) const;
Expand Down Expand Up @@ -581,7 +580,7 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro

template<typename I>
void addInitialiserParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser,
IsHeterogeneousFn isHeterogeneous, std::optional<IsDynamicFn> isDynamic = std::nullopt)
std::optional<IsDynamicFn> isDynamic = std::nullopt)
{
// Loop through params
const auto &initialiser = std::invoke(getInitialiser, this->getGroup().getArchetype());
Expand All @@ -600,26 +599,20 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
},
"", GroupMergedFieldType::DYNAMIC);
}
// Otherwise, if parameter is heterogeneous across merged group
else if(std::invoke(isHeterogeneous, this->getGroup(), p.name)) {
// Otherwise, add standard field
else {
addField(resolvedType.addConst(), p.name,
resolvedType, p.name + fieldSuffix,
[p, getInitialiser](const auto &g, size_t)
{
return std::invoke(getInitialiser, g).getParams().at(p.name);
});
}
// Otherwise, just add a const-qualified scalar to the type environment
else {
add(resolvedType.addConst(), p.name,
Type::writeNumeric(initialiser.getParams().at(p.name), resolvedType));
}
}
}

template<typename I>
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser,
IsHeterogeneousFn isHeterogeneous)
void addInitialiserDerivedParams(const std::string &fieldSuffix, GetInitialiserFn<I> getInitialiser)
{
// Loop through params
const auto &initialiser = std::invoke(getInitialiser, this->getGroup().getArchetype());
Expand All @@ -628,24 +621,16 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
// If parameter is heterogeneous, add scalar field
const auto resolvedType = d.type.resolve(this->getGroup().getTypeContext());
assert(!resolvedType.isPointer());
if (std::invoke(isHeterogeneous, this->getGroup(), d.name)) {
addField(resolvedType.addConst(), d.name, resolvedType, d.name + fieldSuffix,
[d, getInitialiser](const auto &g, size_t)
{
return std::invoke(getInitialiser, g).getDerivedParams().at(d.name);
});
}
// Otherwise, just add a const-qualified scalar to the type environment
else {
add(resolvedType.addConst(), d.name,
Type::writeNumeric(initialiser.getDerivedParams().at(d.name), resolvedType));
}
addField(resolvedType.addConst(), d.name, resolvedType, d.name + fieldSuffix,
[d, getInitialiser](const auto &g, size_t)
{
return std::invoke(getInitialiser, g).getDerivedParams().at(d.name);
});
}
}

template<typename A>
void addVarInitParams(IsVarInitHeterogeneousFn isHeterogeneous,
const std::string &varName, const std::string &fieldSuffix = "")
void addVarInitParams(const std::string &varName, const std::string &fieldSuffix = "")
{
// Loop through parameters
const auto &initialiser = A(this->getGroup().getArchetype()).getInitialisers().at(varName);
Expand All @@ -654,24 +639,16 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
// If parameter is heterogeneous, add field
const auto resolvedType = p.type.resolve(this->getGroup().getTypeContext());
assert(!resolvedType.isPointer());
if(std::invoke(isHeterogeneous, this->getGroup(), varName, p.name)) {
addField(resolvedType.addConst(), p.name, resolvedType, p.name + varName + fieldSuffix,
[p, varName](const auto &g, size_t)
{
return A(g).getInitialisers().at(varName).getParams().at(p.name);
});
}
// Otherwise, just add a const-qualified scalar to the type environment with archetype value
else {
add(resolvedType.addConst(), p.name,
Type::writeNumeric(initialiser.getParams().at(p.name), resolvedType));
}
addField(resolvedType.addConst(), p.name, resolvedType, p.name + varName + fieldSuffix,
[p, varName](const auto &g, size_t)
{
return A(g).getInitialisers().at(varName).getParams().at(p.name);
});
}
}

template<typename A>
void addVarInitDerivedParams(IsVarInitHeterogeneousFn isHeterogeneous,
const std::string &varName, const std::string &fieldSuffix = "")
void addVarInitDerivedParams(const std::string &varName, const std::string &fieldSuffix = "")
{
// Loop through derived parameters
const auto &initialiser = A(this->getGroup().getArchetype()).getInitialisers().at(varName);
Expand All @@ -680,18 +657,11 @@ class EnvironmentGroupMergedField : public EnvironmentExternalDynamicBase<Enviro
// If derived parameter is heterogeneous, add scalar field
const auto resolvedType = d.type.resolve(this->getGroup().getTypeContext());
assert(!resolvedType.isPointer());
if(std::invoke(isHeterogeneous, this->getGroup(), varName, d.name)) {
addField(resolvedType.addConst(), d.name, resolvedType, d.name + varName + fieldSuffix,
[d, varName](const auto &g, size_t)
{
return A(g).getInitialisers().at(varName).getDerivedParams().at(d.name);
});
}
// Otherwise, just add a const-qualified valuie to the type environment with archetype value
else {
add(resolvedType.addConst(), d.name,
Type::writeNumeric(initialiser.getDerivedParams().at(d.name), resolvedType));
}
addField(resolvedType.addConst(), d.name, resolvedType, d.name + varName + fieldSuffix,
[d, varName](const auto &g, size_t)
{
return A(g).getInitialisers().at(varName).getDerivedParams().at(d.name);
});
}
}

Expand Down
15 changes: 0 additions & 15 deletions include/genn/genn/code_generator/groupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,6 @@ class ChildGroupMerged
//------------------------------------------------------------------------
// Protected API
//------------------------------------------------------------------------
//! Helper to test whether parameter values are heterogeneous within merged group
template<typename P>
bool isParamValueHeterogeneous(const std::string &name, P getParamValuesFn) const
{
// Get value of parameter in archetype group
const auto archetypeValue = getParamValuesFn(getArchetype()).at(name);

// Return true if any parameter values differ from the archetype value
return std::any_of(getGroups().cbegin(), getGroups().cend(),
[&name, archetypeValue, getParamValuesFn](const GroupInternal &g)
{
return (getParamValuesFn(g).at(name) != archetypeValue);
});
}

//! Helper to update hash with the hash of calling getHashableFn on each group
template<typename H>
void updateHash(H getHashableFn, boost::uuids::detail::sha1 &hash) const
Expand Down
44 changes: 0 additions & 44 deletions include/genn/genn/code_generator/initGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,6 @@ class InitGroupMergedBase : public B
public:
using B::B;

//----------------------------------------------------------------------------
// Public API
//----------------------------------------------------------------------------
//! Should the var init parameter be implemented heterogeneously?
bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const
{
return this->isParamValueHeterogeneous(paramName,
[&varName](const auto &g)
{
return A(g).getInitialisers().at(varName).getParams();
});
}

//! Should the var init derived parameter be implemented heterogeneously?
bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const
{
return this->isParamValueHeterogeneous(paramName,
[&varName](const auto &g)
{
return A(g).getInitialisers().at(varName).getDerivedParams();
});
}
protected:
//----------------------------------------------------------------------------
// Protected methods
Expand Down Expand Up @@ -307,18 +285,6 @@ class GENN_EXPORT SynapseConnectivityInitGroupMerged : public GroupMerged<Synaps
void generateSparseColumnInit(EnvironmentExternalBase &env);
void generateKernelInit(EnvironmentExternalBase &env, unsigned int batchSize);

//! Should the var init parameter be implemented heterogeneously?
bool isVarInitParamHeterogeneous(const std::string &varName, const std::string &paramName) const;

//! Should the var init derived parameter be implemented heterogeneously?
bool isVarInitDerivedParamHeterogeneous(const std::string &varName, const std::string &paramName) const;

//! Should the sparse connectivity initialization parameter be implemented heterogeneously?
bool isSparseConnectivityInitParamHeterogeneous(const std::string &paramName) const;

//! Should the sparse connectivity initialization parameter be implemented heterogeneously?
bool isSparseConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const;

//----------------------------------------------------------------------------
// Static constants
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -355,16 +321,6 @@ class GENN_EXPORT SynapseConnectivityHostInitGroupMerged : public GroupMerged<Sy
// Static constants
//----------------------------------------------------------------------------
static const std::string name;

private:
//------------------------------------------------------------------------
// Private methods
//------------------------------------------------------------------------
//! Should the connectivity initialization parameter be implemented heterogeneously for EGP init?
bool isConnectivityInitParamHeterogeneous(const std::string &paramName) const;

//! Should the connectivity initialization derived parameter be implemented heterogeneously for EGP init?
bool isConnectivityInitDerivedParamHeterogeneous(const std::string &paramName) const;
};

// ----------------------------------------------------------------------------
Expand Down
35 changes: 0 additions & 35 deletions include/genn/genn/code_generator/neuronUpdateGroupMerged.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Update hash with child groups
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Should the current source parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand All @@ -53,12 +47,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Update hash with child groups
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Should the current source parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -115,12 +103,7 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Update hash with child groups
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Should the current source parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
private:
void generateEventConditionInternal(EnvironmentExternalBase &env, NeuronUpdateGroupMerged &ng,
unsigned int batchSize, BackendBase::GroupHandlerEnv<SynSpikeEvent> genEmitSpikeLikeEvent,
Expand Down Expand Up @@ -149,12 +132,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Update hash with child groups
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Should the current source parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

//----------------------------------------------------------------------------
Expand All @@ -177,12 +154,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase

//! Update hash with child groups
void updateHash(boost::uuids::detail::sha1 &hash) const;

//! Should the current source parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the current source derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;
};

NeuronUpdateGroupMerged(size_t index, const Type::TypeContext &typeContext,
Expand Down Expand Up @@ -219,12 +190,6 @@ class GENN_EXPORT NeuronUpdateGroupMerged : public NeuronGroupMergedBase
const std::vector<SynSpikeEvent> &getMergedSpikeEventGroups() const{ return m_MergedSpikeEventGroups; }
const std::vector<InSynWUMPostCode> &getMergedInSynWUMPostCodeGroups() const { return m_MergedInSynWUMPostCodeGroups; }
const std::vector<OutSynWUMPreCode> &getMergedOutSynWUMPreCodeGroups() const { return m_MergedOutSynWUMPreCodeGroups; }

//! Should the parameter be implemented heterogeneously?
bool isParamHeterogeneous(const std::string &paramName) const;

//! Should the derived parameter be implemented heterogeneously?
bool isDerivedParamHeterogeneous(const std::string &paramName) const;

//----------------------------------------------------------------------------
// Static constants
Expand Down
Loading

0 comments on commit 0182010

Please sign in to comment.