Skip to content
This repository has been archived by the owner on Jul 8, 2022. It is now read-only.

Commit

Permalink
Enabling typechecking of top-level procs whose config functions take …
Browse files Browse the repository at this point in the history
…channel params.

This enables IR conversion of the same.

This change is necessary because converting a construct to IR requires that it has been typechecked, but prior to this change, any proc with config function params could not be typechecked without a driving instantiation.

Naively enabling typechecking of such procs was problematic, because DeduceSpawn requires constexpr eval of all arguments, but a top-level proc has nothing to provide such arguments. Thus, only channels are allowed as arguments to top-level procs, and dummy channels are created in DeduceParam.

With that in place, interpretation would then be broken, because a TypeInfo recurses through its parent to locate items and constexpr values, and so a proc instantiated by a TestProc or other proc would only ever find the dummy channels created by DeduceParam, breaking network instantiation. To fix that, top-level procs are typechecked in a child TypeInfo of the top-level module TypeInfo, instead of that top-level TypeInfo itself.

PiperOrigin-RevId: 458523040
  • Loading branch information
Rob Springer authored and copybara-github committed Jul 1, 2022
1 parent a93c40b commit 42efb43
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 101 deletions.
4 changes: 4 additions & 0 deletions xls/dslx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ cc_test(
":create_import_data",
":import_data",
":parse_and_typecheck",
":symbolic_bindings",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -182,6 +183,7 @@ cc_test(
":import_data",
":interp_value",
":parse_and_typecheck",
":symbolic_bindings",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -917,6 +919,7 @@ cc_library(
":errors",
":import_routines",
":parametric_instantiator",
":symbolic_bindings",
":type_info_to_proto",
"//xls/common/status:status_macros",
"@com_github_google_re2//:re2",
Expand All @@ -932,6 +935,7 @@ cc_test(
":ast",
":create_import_data",
":parse_and_typecheck",
":symbolic_bindings",
":type_info_to_proto",
":typecheck",
"//xls/common:xls_gunit_main",
Expand Down
25 changes: 19 additions & 6 deletions xls/dslx/bytecode_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "xls/dslx/create_import_data.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/parse_and_typecheck.h"
#include "xls/dslx/symbolic_bindings.h"

namespace xls::dslx {
namespace {
Expand Down Expand Up @@ -1206,10 +1207,11 @@ proc Foo {
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(Proc * p, tm.module->GetMemberOrError<Proc>("Foo"));
XLS_ASSERT_OK_AND_ASSIGN(TypeInfo * ti,
tm.type_info->GetTopLevelProcTypeInfo(p));
XLS_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<BytecodeFunction> bf,
BytecodeEmitter::Emit(&import_data, tm.type_info, p->config(),
SymbolicBindings()));
BytecodeEmitter::Emit(&import_data, ti, p->config(), SymbolicBindings()));
const std::vector<Bytecode>& config_bytecodes = bf->bytecodes();
ASSERT_EQ(config_bytecodes.size(), 7);
const std::vector<std::string> kConfigExpected = {
Expand Down Expand Up @@ -1261,11 +1263,20 @@ proc Parent {
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(Proc * p,
XLS_ASSERT_OK_AND_ASSIGN(Proc * parent,
tm.module->GetMemberOrError<Proc>("Parent"));
XLS_ASSERT_OK_AND_ASSIGN(Proc * child,
tm.module->GetMemberOrError<Proc>("Child"));
Spawn* spawn = down_cast<Spawn*>(
down_cast<Let*>(parent->config()->body()->body())->body());
XLS_ASSERT_OK_AND_ASSIGN(TypeInfo * parent_ti,
tm.type_info->GetTopLevelProcTypeInfo(parent));
TypeInfo* child_ti =
parent_ti->GetInvocationTypeInfo(spawn->config(), SymbolicBindings())
.value();
XLS_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<BytecodeFunction> bf,
BytecodeEmitter::Emit(&import_data, tm.type_info, p->config(),
BytecodeEmitter::Emit(&import_data, child_ti, child->config(),
SymbolicBindings()));
const std::vector<Bytecode>& config_bytecodes = bf->bytecodes();
ASSERT_EQ(config_bytecodes.size(), 8);
Expand All @@ -1279,11 +1290,13 @@ proc Parent {
}

std::vector<NameDef*> members;
for (const Param* member : p->members()) {
for (const Param* member : child->members()) {
members.push_back(member->name_def());
}
child_ti = parent_ti->GetInvocationTypeInfo(spawn->next(), SymbolicBindings())
.value();
XLS_ASSERT_OK_AND_ASSIGN(
bf, BytecodeEmitter::EmitProcNext(&import_data, tm.type_info, p->next(),
bf, BytecodeEmitter::EmitProcNext(&import_data, child_ti, child->next(),
SymbolicBindings(), members));
const std::vector<Bytecode>& next_bytecodes = bf->bytecodes();
ASSERT_EQ(next_bytecodes.size(), 16);
Expand Down
25 changes: 14 additions & 11 deletions xls/dslx/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1561,10 +1561,11 @@ absl::Status ProcConfigBytecodeInterpreter::EvalSpawn(
std::vector<ProcInstance>* proc_instances) {
auto get_parametric_type_info =
[type_info](const Spawn* spawn, const Invocation* invoc,
const SymbolicBindings& caller_bindings)
const std::optional<SymbolicBindings>& caller_bindings)
-> absl::StatusOr<TypeInfo*> {
absl::optional<TypeInfo*> maybe_type_info =
type_info->GetInvocationTypeInfo(invoc, caller_bindings);
std::optional<TypeInfo*> maybe_type_info = type_info->GetInvocationTypeInfo(
invoc, caller_bindings.has_value() ? caller_bindings.value()
: SymbolicBindings());
if (!maybe_type_info.has_value()) {
return absl::InternalError(
absl::StrCat("Could not find type info for invocation ",
Expand All @@ -1573,15 +1574,17 @@ absl::Status ProcConfigBytecodeInterpreter::EvalSpawn(
return maybe_type_info.value();
};

if (proc->IsParametric()) {
// We need to get a new TI if there's a spawn, i.e., this isn't a top-level
// proc instantiation, to avoid constexpr values from colliding between
// different proc instantiations.
if (maybe_spawn.has_value()) {
// We're guaranteed that these have values if the proc is parametric (the
// root proc can't be parametric).
XLS_RET_CHECK(caller_bindings.has_value());
XLS_RET_CHECK(maybe_spawn.has_value());
XLS_ASSIGN_OR_RETURN(type_info,
get_parametric_type_info(maybe_spawn.value(),
maybe_spawn.value()->config(),
caller_bindings.value()));
caller_bindings));
}

XLS_ASSIGN_OR_RETURN(std::unique_ptr<BytecodeFunction> config_bf,
Expand Down Expand Up @@ -1611,11 +1614,11 @@ absl::Status ProcConfigBytecodeInterpreter::EvalSpawn(
member_defs.push_back(param->name_def());
}

if (proc->IsParametric()) {
XLS_ASSIGN_OR_RETURN(type_info,
get_parametric_type_info(maybe_spawn.value(),
maybe_spawn.value()->next(),
caller_bindings.value()));
if (maybe_spawn.has_value()) {
XLS_ASSIGN_OR_RETURN(
type_info,
get_parametric_type_info(maybe_spawn.value(),
maybe_spawn.value()->next(), caller_bindings));
}

XLS_ASSIGN_OR_RETURN(
Expand Down
10 changes: 5 additions & 5 deletions xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <variant>

#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/types/variant.h"
#include "xls/common/status/status_macros.h"
Expand Down Expand Up @@ -329,7 +330,7 @@ absl::Status ConstexprEvaluator::HandleCast(const Cast* expr) {
}

// Creates an InterpValue for the described channel or array of channels.
absl::StatusOr<InterpValue> CreateChannelValue(
absl::StatusOr<InterpValue> ConstexprEvaluator::CreateChannelValue(
const ConcreteType* concrete_type) {
if (auto* array_type = dynamic_cast<const ArrayType*>(concrete_type)) {
XLS_ASSIGN_OR_RETURN(int dim_int, array_type->size().GetAsInt64());
Expand Down Expand Up @@ -374,10 +375,9 @@ absl::Status ConstexprEvaluator::HandleChannelDecl(const ChannelDecl* expr) {
expr->span(), tuple_type, "ChannelDecl type was a two-element tuple.");
}

absl::StatusOr<InterpValue> channels =
CreateChannelValue(&tuple_type->GetMemberType(0));
type_info_->NoteConstExpr(
expr, InterpValue::MakeTuple({channels.value(), channels.value()}));
XLS_ASSIGN_OR_RETURN(InterpValue channel,
CreateChannelValue(&tuple_type->GetMemberType(0)));
type_info_->NoteConstExpr(expr, InterpValue::MakeTuple({channel, channel}));
return absl::OkStatus();
}

Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/constexpr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class ConstexprEvaluator : public xls::dslx::ExprVisitor {
absl::Status HandleUnop(const Unop* expr) override;
absl::Status HandleXlsTuple(const XlsTuple* expr) override;

static absl::StatusOr<InterpValue> CreateChannelValue(
const ConcreteType* concrete_type);

private:
ConstexprEvaluator(ImportData* import_data, TypeInfo* type_info,
SymbolicBindings bindings,
Expand Down
28 changes: 27 additions & 1 deletion xls/dslx/deduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,33 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceUnop(const Unop* node,

absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceParam(const Param* node,
DeduceCtx* ctx) {
return ctx->Deduce(node->type_annotation());
XLS_ASSIGN_OR_RETURN(auto concrete_type,
ctx->Deduce(node->type_annotation()));
Function* f = dynamic_cast<Function*>(node->parent());
if (f == nullptr) {
return concrete_type;
}

// When deducing a proc at top level, we won't have constexpr values for its
// config params, which will cause Spawn deduction to fail, so we need to
// create dummy InterpValues for its parameter channels.
// Other types of params aren't allowed, example: a proc member could be
// assigned a constexpr value based on the sum of dummy values.
// Stack depth 2: Module "<top>" + the config function being looked at.
bool is_root_proc =
f->tag() == Function::Tag::kProcConfig && ctx->fn_stack().size() == 2;
bool is_channel_param =
dynamic_cast<ChannelType*>(concrete_type.get()) != nullptr;
bool is_param_constexpr = ctx->type_info()->IsKnownConstExpr(node);
if (is_root_proc && is_channel_param && !is_param_constexpr) {
XLS_ASSIGN_OR_RETURN(
InterpValue value,
ConstexprEvaluator::CreateChannelValue(concrete_type.get()));
ctx->type_info()->NoteConstExpr(node, value);
ctx->type_info()->NoteConstExpr(node->name_def(), value);
}

return concrete_type;
}

absl::StatusOr<std::unique_ptr<ConcreteType>> DeduceConstantDef(
Expand Down
7 changes: 7 additions & 0 deletions xls/dslx/deduce_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ class DeduceCtx {
type_info_ = type_info_owner().New(module(), /*parent=*/type_info_).value();
}

// Puts the given TypeInfo on top of the current stack.
absl::Status PushTypeInfo(TypeInfo* ti) {
XLS_RET_CHECK_EQ(ti->parent(), type_info_);
type_info_ = ti;
return absl::OkStatus();
}

// Pops the current type_info_ and sets the type_info_ to be the popped
// value's parent (conceptually an inverse of AddDerivedTypeInfo()).
absl::Status PopDerivedTypeInfo() {
Expand Down
20 changes: 16 additions & 4 deletions xls/dslx/extract_conversion_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,13 @@ absl::StatusOr<std::vector<ConversionRecord>> GetOrder(Module* module,
// Collect the top level procs.
XLS_ASSIGN_OR_RETURN(std::vector<Proc*> top_level_procs,
GetTopLevelProcs(module, type_info));
// Get the Order for each top level proc.
// Get the order for each top level proc.
for (Proc* proc : top_level_procs) {
XLS_ASSIGN_OR_RETURN(TypeInfo * proc_ti,
type_info->GetTopLevelProcTypeInfo(proc));

XLS_ASSIGN_OR_RETURN(std::vector<ConversionRecord> proc_ready,
GetOrderForProc(proc, type_info, /*is_top=*/false));
GetOrderForProc(proc, proc_ti, /*is_top=*/false));
ready.insert(ready.end(), proc_ready.begin(), proc_ready.end());
}

Expand All @@ -789,8 +792,10 @@ absl::StatusOr<std::vector<ConversionRecord>> GetOrder(Module* module,
{}));
}
for (TestProc* test : module->GetProcTests()) {
XLS_ASSIGN_OR_RETURN(TypeInfo * proc_ti,
type_info->GetTopLevelProcTypeInfo(test->proc()));
XLS_ASSIGN_OR_RETURN(std::vector<ConversionRecord> proc_ready,
GetOrderForProc(test, type_info, /*is_top=*/false));
GetOrderForProc(test, proc_ti, /*is_top=*/false));
ready.insert(ready.end(), proc_ready.begin(), proc_ready.end());
}
}
Expand All @@ -811,14 +816,21 @@ absl::StatusOr<std::vector<ConversionRecord>> GetOrderForEntry(
std::vector<ConversionRecord> ready;
if (absl::holds_alternative<Function*>(entry)) {
Function* f = absl::get<Function*>(entry);
if (f->proc().has_value()) {
XLS_ASSIGN_OR_RETURN(
type_info, type_info->GetTopLevelProcTypeInfo(f->proc().value()));
}
XLS_RETURN_IF_ERROR(AddToReady(f,
/*invocation=*/nullptr, f->owner(),
type_info, SymbolicBindings(), &ready, {},
/*is_top=*/true));
return ready;
}

Proc* p = absl::get<Proc*>(entry);
return GetOrderForProc(p, type_info, /*is_top=*/true);
XLS_ASSIGN_OR_RETURN(TypeInfo * new_ti,
type_info->GetTopLevelProcTypeInfo(p));
return GetOrderForProc(p, new_ti, /*is_top=*/true);
}

} // namespace xls::dslx
37 changes: 0 additions & 37 deletions xls/dslx/extract_conversion_order_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,43 +550,6 @@ proc main {
EXPECT_EQ(order[17].proc_id().value().ToString(), "main->p2:0");
}

TEST(ExtractConversionOrderTest, FunctionProcMixed) {
constexpr absl::string_view kProgram = R"(
fn f0() -> u32 {
u32:42
}
fn f1() -> u32 {
u32:24
}
proc main {
config() { () }
next(tok: token, x: u32) {
(f0(),)
}
}
)";
auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));
XLS_ASSERT_OK_AND_ASSIGN(std::vector<ConversionRecord> order,
GetOrder(tm.module, tm.type_info));
ASSERT_EQ(4, order.size());
ASSERT_FALSE(order[0].proc_id().has_value());
ASSERT_FALSE(order[1].proc_id().has_value());
ASSERT_TRUE(order[2].proc_id().has_value());
ASSERT_TRUE(order[3].proc_id().has_value());
EXPECT_EQ(order[0].f()->identifier(), "f0");
EXPECT_EQ(order[1].f()->identifier(), "f1");
EXPECT_EQ(order[2].f()->identifier(), "main.config");
EXPECT_EQ(order[2].proc_id().value().ToString(), "main:0");
EXPECT_EQ(order[3].f()->identifier(), "main.next");
EXPECT_EQ(order[3].proc_id().value().ToString(), "main:0");
}

TEST(ExtractConversionOrderTest, ProcNetworkWithTwoTopLevelProcs) {
constexpr absl::string_view kProgram = R"(
proc p2 {
Expand Down
1 change: 0 additions & 1 deletion xls/dslx/ir_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3161,7 +3161,6 @@ absl::StatusOr<std::unique_ptr<ConcreteType>> FunctionConverter::ResolveType(
XLS_RET_CHECK(current_type_info_ != nullptr);
absl::optional<const ConcreteType*> t = current_type_info_->GetItem(node);
if (!t.has_value()) {
XLS_LOG(INFO) << "NODE: " << node << " : TI: " << current_type_info_;
return ConversionErrorStatus(
node->GetSpan(),
absl::StrFormat(
Expand Down
22 changes: 11 additions & 11 deletions xls/dslx/proc_config_ir_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ class ProcConfigIrConverter : public AstNodeVisitorWithDefault {
const SymbolicBindings& bindings,
const ProcId& proc_id);

absl::Status HandleBlock(const Block* node);
absl::Status HandleChannelDecl(const ChannelDecl* node);
absl::Status HandleFunction(const Function* node);
absl::Status HandleInvocation(const Invocation* node);
absl::Status HandleLet(const Let* node);
absl::Status HandleNameRef(const NameRef* node);
absl::Status HandleNumber(const Number* node);
absl::Status HandleParam(const Param* node);
absl::Status HandleSpawn(const Spawn* node);
absl::Status HandleStructInstance(const StructInstance* node);
absl::Status HandleXlsTuple(const XlsTuple* node);
absl::Status HandleBlock(const Block* node) override;
absl::Status HandleChannelDecl(const ChannelDecl* node) override;
absl::Status HandleFunction(const Function* node) override;
absl::Status HandleInvocation(const Invocation* node) override;
absl::Status HandleLet(const Let* node) override;
absl::Status HandleNameRef(const NameRef* node) override;
absl::Status HandleNumber(const Number* node) override;
absl::Status HandleParam(const Param* node) override;
absl::Status HandleSpawn(const Spawn* node) override;
absl::Status HandleStructInstance(const StructInstance* node) override;
absl::Status HandleXlsTuple(const XlsTuple* node) override;

// Sets the mapping from the elements in the config-ending tuple to the
// corresponding Proc members.
Expand Down
11 changes: 6 additions & 5 deletions xls/dslx/run_routines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,18 @@ absl::Status RunTestProc(ImportData* import_data, TypeInfo* type_info,
auto cache = std::make_unique<BytecodeCache>(import_data);
import_data->SetBytecodeCache(std::move(cache));

XLS_ASSIGN_OR_RETURN(TypeInfo * ti,
type_info->GetTopLevelProcTypeInfo(tp->proc()));
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<BytecodeFunction> bf,
BytecodeEmitter::Emit(import_data, type_info, tp->proc()->config(),
BytecodeEmitter::Emit(import_data, ti, tp->proc()->config(),
absl::nullopt));

std::vector<ProcInstance> proc_instances;
XLS_ASSIGN_OR_RETURN(
InterpValue terminator,
type_info->GetConstExpr(tp->proc()->config()->params()[0]));
XLS_ASSIGN_OR_RETURN(InterpValue terminator,
ti->GetConstExpr(tp->proc()->config()->params()[0]));
XLS_RETURN_IF_ERROR(ProcConfigBytecodeInterpreter::InitializeProcNetwork(
import_data, type_info, tp->proc(), terminator, tp->next_args(),
import_data, ti, tp->proc(), terminator, tp->next_args(),
&proc_instances));

std::shared_ptr<InterpValue::Channel> term_chan =
Expand Down
Loading

0 comments on commit 42efb43

Please sign in to comment.