Skip to content
This repository has been archived by the owner on Jul 8, 2022. It is now read-only.

Commit

Permalink
DSLX: Converting tuple indexing from Python form to Rust form.
Browse files Browse the repository at this point in the history
Like with Rust (https://doc.rust-lang.org/reference/tokens.html#tuple-index), the index itself must be a literal number, and not a parametric value.

Fixes google#633.

PiperOrigin-RevId: 457735238
  • Loading branch information
Rob Springer authored and copybara-github committed Jun 28, 2022
1 parent 1bcd0f9 commit eeacef1
Show file tree
Hide file tree
Showing 45 changed files with 298 additions and 112 deletions.
4 changes: 3 additions & 1 deletion docs_src/dslx_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,12 @@ example, to access the second element of a tuple (index 1):
#![test]
fn test_tuple_access() {
let t = (u32:2, u8:3);
assert_eq(u8:3, t[1])
assert_eq(u8:3, t.1)
}
```

Such indices can only be numeric literals; parametric symbols are not allowed.

Tuples can be "destructured", similarly to how pattern matching works in `match`
expressions, which provides a convenient syntax to name elements of a tuple for
subsequent use. See `a` and `b` in the following:
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ cc_library(
":scanner",
":type_and_bindings",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/status",
Expand Down
21 changes: 21 additions & 0 deletions xls/dslx/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ std::string_view AstNodeKindToString(AstNodeKind kind) {
return "channel declaration";
case AstNodeKind::kParametricBinding:
return "parametric binding";
case AstNodeKind::kTupleIndex:
return "tuple index";
}
XLS_LOG(FATAL) << "Out-of-range AstNodeKind: " << static_cast<int>(kind);
}
Expand Down Expand Up @@ -1711,6 +1713,25 @@ std::string QuickCheck::ToString() const {
f_->ToString());
}

TupleIndex::TupleIndex(Module* owner, Span span, Expr* lhs, Number* index)
: Expr(owner, std::move(span)), lhs_(lhs), index_(index) {}

absl::Status TupleIndex::Accept(AstNodeVisitor* v) const {
return v->HandleTupleIndex(this);
}

absl::Status TupleIndex::AcceptExpr(ExprVisitor* v) const {
return v->HandleTupleIndex(this);
}

std::string TupleIndex::ToString() const {
return absl::StrCat(lhs_->ToString(), ".", index_->ToString());
}

std::vector<AstNode*> TupleIndex::GetChildren(bool want_types) const {
return {lhs_, index_};
}

std::string XlsTuple::ToString() const {
std::string result = "(";
for (int64_t i = 0; i < members_.size(); ++i) {
Expand Down
22 changes: 22 additions & 0 deletions xls/dslx/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ bool IsOneOf(ObjT* obj) {
X(String) \
X(StructInstance) \
X(Ternary) \
X(TupleIndex) \
X(Unop) \
X(XlsTuple)

Expand Down Expand Up @@ -249,6 +250,7 @@ enum class AstNodeKind {
kLet,
kChannelDecl,
kParametricBinding,
kTupleIndex,
};

std::string_view AstNodeKindToString(AstNodeKind kind);
Expand Down Expand Up @@ -621,6 +623,7 @@ class ExprVisitor {
virtual absl::Status HandleSplatStructInstance(
const SplatStructInstance* expr) = 0;
virtual absl::Status HandleTernary(const Ternary* expr) = 0;
virtual absl::Status HandleTupleIndex(const TupleIndex* expr) = 0;
virtual absl::Status HandleUnop(const Unop* expr) = 0;
virtual absl::Status HandleXlsTuple(const XlsTuple* expr) = 0;
};
Expand Down Expand Up @@ -2265,6 +2268,25 @@ class QuickCheck : public AstNode {
absl::optional<int64_t> test_count_;
};

// Represents an index into a tuple, e.g., "(u32:7, u32:8).1".
class TupleIndex : public Expr {
public:
TupleIndex(Module* owner, Span span, Expr* lhs, Number* index);
AstNodeKind kind() const { return AstNodeKind::kTupleIndex; }
absl::Status Accept(AstNodeVisitor* v) const override;
absl::Status AcceptExpr(ExprVisitor* v) const override;
absl::string_view GetNodeTypeName() const override { return "TupleIndex"; }
std::string ToString() const override;
std::vector<AstNode*> GetChildren(bool want_types) const override;

Expr* lhs() const { return lhs_; }
Number* index() const { return index_; }

private:
Expr* lhs_;
Number* index_;
};

// Represents an XLS tuple expression.
class XlsTuple : public Expr {
public:
Expand Down
8 changes: 8 additions & 0 deletions xls/dslx/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,14 @@ class AstCloner : public AstNodeVisitor {
return absl::OkStatus();
}

absl::Status HandleTupleIndex(const TupleIndex* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));
old_to_new_[n] = module_->Make<TupleIndex>(
n->span(), down_cast<Expr*>(old_to_new_.at(n->lhs())),
down_cast<Number*>(old_to_new_.at(n->index())));
return absl::OkStatus();
}

absl::Status HandleTupleTypeAnnotation(
const TupleTypeAnnotation* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));
Expand Down
12 changes: 12 additions & 0 deletions xls/dslx/ast_cloner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,18 @@ TEST(AstClonerTest, NormalFor) {
EXPECT_EQ(kProgram, clone->ToString());
}

TEST(AstClonerTest, TupleIndex) {
constexpr absl::string_view kProgram = R"(fn main() -> u32 {
(u8:8, u16:16, u32:32, u64:64).2
})";

XLS_ASSERT_OK_AND_ASSIGN(auto module,
ParseModule(kProgram, "fake_path.x", "the_module"));
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Module> clone,
CloneModule(module.get()));
EXPECT_EQ(kProgram, clone->ToString());
}

TEST(AstClonerTest, SendsAndRecvsAndSpawns) {
constexpr absl::string_view kProgram = R"(import other_module
proc MyProc {
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ std::string Bytecode::MatchArmItem::ToString() const {
}

DEF_UNARY_BUILDER(Dup);
DEF_UNARY_BUILDER(Index);
DEF_UNARY_BUILDER(Invert);
DEF_UNARY_BUILDER(JumpDest);
DEF_UNARY_BUILDER(LogicalOr);
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class Bytecode {

static Bytecode MakeDup(Span span);
static Bytecode MakeFail(Span span, std::string);
static Bytecode MakeIndex(Span span);
static Bytecode MakeInvert(Span span);
static Bytecode MakeJumpDest(Span span);
static Bytecode MakeJumpRelIf(Span span, JumpTarget target);
Expand Down
24 changes: 15 additions & 9 deletions xls/dslx/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class NameDefCollector : public AstNodeVisitor {
}
DEFAULT_HANDLER(Ternary);
DEFAULT_HANDLER(TestProc);
DEFAULT_HANDLER(TupleIndex);
DEFAULT_HANDLER(TupleTypeAnnotation);
DEFAULT_HANDLER(TypeDef);
DEFAULT_HANDLER(TypeRef);
Expand Down Expand Up @@ -251,9 +252,8 @@ absl::Status BytecodeEmitter::HandleAttr(const Attr* node) {

// This indexing literal needs to be unsigned since InterpValue::Index
// requires an unsigned value.
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeU64(member_index)));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kIndex));
Add(Bytecode::MakeLiteral(node->span(), InterpValue::MakeU64(member_index)));
Add(Bytecode::MakeIndex(node->span()));
return absl::OkStatus();
}

Expand Down Expand Up @@ -610,7 +610,7 @@ absl::Status BytecodeEmitter::HandleFor(const For* node) {
Bytecode::SlotIndex(iterable_slot)));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLoad,
Bytecode::SlotIndex(index_slot)));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kIndex));
Add(Bytecode::MakeIndex(node->span()));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kSwap));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kCreateTuple,
Bytecode::NumElements(2)));
Expand Down Expand Up @@ -739,7 +739,7 @@ absl::Status BytecodeEmitter::HandleIndex(const Index* node) {
// Otherwise, it's a regular [array or tuple] index op.
Expr* expr = absl::get<Expr*>(node->rhs());
XLS_RETURN_IF_ERROR(expr->AcceptExpr(this));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kIndex));
Add(Bytecode::MakeIndex(node->span()));
return absl::OkStatus();
}

Expand Down Expand Up @@ -1093,9 +1093,8 @@ absl::Status BytecodeEmitter::HandleSplatStructInstance(
XLS_RETURN_IF_ERROR(new_members.at(member_name)->AcceptExpr(this));
} else {
XLS_RETURN_IF_ERROR(node->splatted()->AcceptExpr(this));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kLiteral,
InterpValue::MakeU64(i)));
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kIndex));
Add(Bytecode::MakeLiteral(node->span(), InterpValue::MakeU64(i)));
Add(Bytecode::MakeIndex(node->span()));
}
}

Expand All @@ -1105,11 +1104,18 @@ absl::Status BytecodeEmitter::HandleSplatStructInstance(
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleTupleIndex(const TupleIndex* node) {
XLS_RETURN_IF_ERROR(node->lhs()->AcceptExpr(this));
XLS_RETURN_IF_ERROR(node->index()->AcceptExpr(this));
Add(Bytecode::MakeIndex(node->span()));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleUnop(const Unop* node) {
XLS_RETURN_IF_ERROR(node->operand()->AcceptExpr(this));
switch (node->unop_kind()) {
case UnopKind::kInvert:
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kInvert));
Add(Bytecode::MakeInvert(node->span()));
break;
case UnopKind::kNegate:
bytecode_.push_back(Bytecode(node->span(), Bytecode::Op::kNegate));
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleSplatStructInstance(
const SplatStructInstance* node) override;
absl::Status HandleTernary(const Ternary* node) override;
absl::Status HandleTupleIndex(const TupleIndex* node) override;
absl::Status HandleUnop(const Unop* node) override;
absl::Status HandleXlsTuple(const XlsTuple* node) override;

Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/bytecode_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ fn index_tuple() -> u32 {
let a = (u16:0, u32:1, u64:2);
let b = (bits[128]:3, bits[32]:4);
a[1] + b[1]
a.1 + b.1
}
)";

Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/bytecode_interpreter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ fn index_tuple() -> u32 {
let a = (u32:0, (u32:1, u32:2));
let b = ((u32:3, (u32:4,)), u32:5);
a[1][1] + b[0][1][0]
a.1.1 + b.0.1.0
})";

auto import_data = CreateImportDataForTest();
Expand Down
21 changes: 21 additions & 0 deletions xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ class NameRefCollector : public ExprVisitor {
XLS_RETURN_IF_ERROR(expr->operand()->AcceptExpr(this));
return absl::OkStatus();
}
absl::Status HandleTupleIndex(const TupleIndex* expr) override {
XLS_RETURN_IF_ERROR(expr->lhs()->AcceptExpr(this));
return absl::OkStatus();
}
absl::Status HandleXlsTuple(const XlsTuple* expr) override {
for (const Expr* member : expr->members()) {
XLS_RETURN_IF_ERROR(member->AcceptExpr(this));
Expand Down Expand Up @@ -630,6 +634,23 @@ absl::Status ConstexprEvaluator::HandleUnop(const Unop* expr) {
return InterpretExpr(expr);
}

absl::Status ConstexprEvaluator::HandleTupleIndex(const TupleIndex* expr) {
// No need to fire up the interpreter. This one is easy.
GET_CONSTEXPR_OR_RETURN(InterpValue tuple, expr->lhs());
GET_CONSTEXPR_OR_RETURN(InterpValue index, expr->index());

XLS_ASSIGN_OR_RETURN(uint64_t index_value, index.GetBitValueUint64());
XLS_ASSIGN_OR_RETURN(const std::vector<InterpValue>* values,
tuple.GetValues());
if (index_value < 0 || index_value > values->size()) {
return absl::InvalidArgumentError(
absl::StrFormat("%s: Out-of-range tuple index: %d vs %d.",
expr->span().ToString(), index_value, values->size()));
}
type_info_->NoteConstExpr(expr, values->at(index_value));
return absl::OkStatus();
}

absl::Status ConstexprEvaluator::HandleXlsTuple(const XlsTuple* expr) {
std::vector<InterpValue> values;
for (const Expr* member : expr->members()) {
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/constexpr_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ConstexprEvaluator : public xls::dslx::ExprVisitor {
absl::Status HandleSplatStructInstance(
const SplatStructInstance* expr) override;
absl::Status HandleTernary(const Ternary* expr) override;
absl::Status HandleTupleIndex(const TupleIndex* expr) override;
absl::Status HandleUnop(const Unop* expr) override;
absl::Status HandleXlsTuple(const XlsTuple* expr) override;

Expand Down
23 changes: 22 additions & 1 deletion xls/dslx/constexpr_evaluator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ fn main() -> u32 {
let upbeef = beef + i;
let beeves = u32[4]:[0, 1, 2, 3];
let beef_tuple = (upbeef, beeves[u32:2]);
upbeef + beeves[u32:1] + beef_tuple[u32:1]
upbeef + beeves[u32:1] + beef_tuple.1
} (beef)
})";

Expand Down Expand Up @@ -589,5 +589,26 @@ fn main() -> u32 {
EXPECT_EQ(value.GetBitValueInt64().value(), 6);
}

TEST(ConstexprEvaluatorTest, BasicTupleIndex) {
constexpr absl::string_view kProgram = R"(
fn main() -> u32 {
(u64:64, u32:32, u16:16, u8:8).1
}
)";

ImportData import_data(CreateImportDataForTest());
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(Function * f,
tm.module->GetMemberOrError<Function>("main"));
XLS_ASSERT_OK(ConstexprEvaluator::Evaluate(
&import_data, tm.type_info, SymbolicBindings(), f->body(), nullptr));
XLS_ASSERT_OK_AND_ASSIGN(InterpValue value,
tm.type_info->GetConstExpr(f->body()));
EXPECT_EQ(value.GetBitValueInt64().value(), 32);
}

} // namespace
} // namespace xls::dslx
Loading

0 comments on commit eeacef1

Please sign in to comment.