Skip to content

Commit

Permalink
[XLS] Add cross-activation token support to TokenDependencyPass
Browse files Browse the repository at this point in the history
This could previously cause optimization to fail when a state element contained both data & a cross-activation token, as the dependency pass tried to treat the state read as an I/O operation participating in the token graph.

PiperOrigin-RevId: 726656869
  • Loading branch information
ericastor authored and copybara-github committed Feb 13, 2025
1 parent 53f9b55 commit 533c748
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 36 deletions.
78 changes: 42 additions & 36 deletions xls/passes/token_dependency_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(
return discovered;
};

// A relation mapping each effectful node to the set of receives that it is
// data-dependent on, but not token-dependent on.
NodeRelation io_to_receive;
// A relation mapping each effectful node to the set of I/O operations that it
// is data-dependent on, but not token-dependent on.
NodeRelation io_to_data_supplying_io;

auto is_side_effecting_token_op = [](Node* n) {
return OpIsSideEffecting(n->op()) && TypeHasToken(n->GetType());
Expand All @@ -133,6 +133,10 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(
if (a->GetType()->IsToken()) {
continue;
}
if (a->Is<StateRead>()) {
// This operation does not consume a token.
continue;
}
if (a->op() != Op::kReceive) {
return absl::InternalError(
"Can't handle token-and-data producing ops other than receive yet");
Expand All @@ -149,25 +153,25 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(
continue;
}

io_to_receive[b].insert(a);
io_to_data_supplying_io[b].insert(a);
}
}

VLOG(3) << "IO to receive:";
XLS_VLOG_LINES(3, relation_to_string(io_to_receive));
VLOG(3) << "IO to data-supplying IO:";
XLS_VLOG_LINES(3, relation_to_string(io_to_data_supplying_io));

// A relation similar to `io_to_receive`, except that receives are only
// included at the earliest points where they have an effect. For example, if
// `C` is token-dependent on both `A` and `B`, and `io_to_receive` contains
// all of `A`, `B`, and `C`, with
// A relation similar to `io_to_data_supplying_io`, except that receives are
// only included at the earliest points where they have an effect. For
// example, if `C` is token-dependent on both `A` and `B`, and
// `io_to_data_supplying_io` contains all of `A`, `B`, and `C`, with
//
// - `io_to_receive[A]` containing `recv1`,
// - `io_to_receive[B]` containing `recv2`,
// - `io_to_receive[C]` containing `recv1`, `recv2`, and `recv3`,
// - `io_to_data_supplying_io[A]` containing `recv1`,
// - `io_to_data_supplying_io[B]` containing `recv2`,
// - `io_to_data_supplying_io[C]` containing `recv1`, `recv2`, and `recv3`,
//
// then `minimal_io_to_receive[C]` will only include `recv3`.
NodeRelation minimal_io_to_receive = io_to_receive;
for (const auto& [io, receives] : io_to_receive) {
// then `minimal_io_to_data_supplying_io[C]` will only include `recv3`.
NodeRelation minimal_io_to_data_supplying_io = io_to_data_supplying_io;
for (const auto& [io, supplying_ios] : io_to_data_supplying_io) {
auto it = token_deps_closure.find(io);
if (it == token_deps_closure.end()) {
continue;
Expand All @@ -177,46 +181,48 @@ absl::StatusOr<bool> TokenDependencyPass::RunOnFunctionBaseInternal(
continue;
}

auto it = minimal_io_to_receive.find(downstream_of_io);
if (it == minimal_io_to_receive.end()) {
auto it = minimal_io_to_data_supplying_io.find(downstream_of_io);
if (it == minimal_io_to_data_supplying_io.end()) {
continue;
}
absl::flat_hash_set<Node*>& downstream_receives = it->second;
for (Node* receive : receives) {
downstream_receives.erase(receive);
absl::flat_hash_set<Node*>& downstream_ios = it->second;
for (Node* supplying_io : supplying_ios) {
downstream_ios.erase(supplying_io);
}
if (downstream_receives.empty()) {
minimal_io_to_receive.erase(it);
if (downstream_ios.empty()) {
minimal_io_to_data_supplying_io.erase(it);
}
}
}
VLOG(3) << "Minimal IO to receive:";
XLS_VLOG_LINES(3, relation_to_string(minimal_io_to_receive));
VLOG(3) << "Minimal IO to data-supplying IO:";
XLS_VLOG_LINES(3, relation_to_string(minimal_io_to_data_supplying_io));

bool changed = false;

// Before touching the IR create a deterministic sort of the keys of the
// relation.
std::vector<Node*> minimal_io_to_receive_keys;
minimal_io_to_receive_keys.reserve(minimal_io_to_receive.size());
for (const auto& [io, _] : minimal_io_to_receive) {
minimal_io_to_receive_keys.push_back(io);
std::vector<Node*> minimal_io_to_data_supplying_io_keys;
minimal_io_to_data_supplying_io_keys.reserve(
minimal_io_to_data_supplying_io.size());
for (const auto& [io, _] : minimal_io_to_data_supplying_io) {
minimal_io_to_data_supplying_io_keys.push_back(io);
}
SortByNodeId(&minimal_io_to_receive_keys);
SortByNodeId(&minimal_io_to_data_supplying_io_keys);

for (Node* io : minimal_io_to_receive_keys) {
for (Node* receive : SetToSortedVector(minimal_io_to_receive.at(io))) {
for (Node* io : minimal_io_to_data_supplying_io_keys) {
for (Node* supplying_io :
SetToSortedVector(minimal_io_to_data_supplying_io.at(io))) {
for (Node* input : io->operands()) {
// We're making the assumption that any token-typed input to an
// effectful operation must be a proper token input.
if (input->GetType()->IsToken()) {
XLS_ASSIGN_OR_RETURN(
Node * receive_token,
f->MakeNode<TupleIndex>(SourceInfo(), receive, 0));
Node * supplying_token,
f->MakeNode<TupleIndex>(SourceInfo(), supplying_io, 0));
XLS_ASSIGN_OR_RETURN(
Node * new_token,
f->MakeNode<AfterAll>(SourceInfo(),
std::vector<Node*>{receive_token, input}));
f->MakeNode<AfterAll>(
SourceInfo(), std::vector<Node*>{supplying_token, input}));
bool operand_replaced = io->ReplaceOperand(input, new_token);
changed = changed || operand_replaced;
}
Expand Down
26 changes: 26 additions & 0 deletions xls/passes/token_dependency_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,31 @@ TEST_F(TokenDependencyPassTest, SideEffectingNontokenOps) {
EXPECT_THAT(Run(proc), IsOkAndHolds(false));
}

TEST_F(TokenDependencyPassTest, SupportsCrossActivationTokens) {
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> p, ParsePackage(R"(
package test_module
chan test_channel(
bits[32], id=0, kind=streaming, ops=send_only,
flow_control=ready_valid, metadata="""""")
top proc main(__state: (bits[32], token, bits[32]), init={(1, token, 1)}) {
a: bits[32] = tuple_index(__state, index=0)
b: bits[32] = tuple_index(__state, index=2)
tok: token = tuple_index(__state, index=1)
c: bits[32] = add(a, b)
new_tok: token = literal(value=token)
snd: token = send(new_tok, c, channel=test_channel)
next_tok: token = after_all(tok, snd)
next_state: (bits[32], token, bits[32]) = tuple(b, next_tok, c)
next (next_state)
}
)"));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, p->GetTopAsProc());
// No changes required; in particular, we don't need the send to depend on the
// cross-activation token as written.
EXPECT_THAT(Run(proc), IsOkAndHolds(false));
}

} // namespace
} // namespace xls

0 comments on commit 533c748

Please sign in to comment.