Skip to content

Commit

Permalink
Monomorphize dropped functions (#6734)
Browse files Browse the repository at this point in the history
We now consider a drop to be part of the call context: If we see

(drop
  (call $foo)
)

(func $foo (result i32)
  (i32.const 42)
)

Then we'd monomorphize to this:

(call $foo_1)  ;; call the specialized function instead

(func $foo_1   ;; the specialized function returns nothing
  (drop        ;; the drop was moved into here
    (i32.const 42)
  )
)

With the drop now in the called function, we may be able to optimize out unused work.

Refactor a bit of code out of DAE that we can reuse here, into a new return-utils.h.
  • Loading branch information
kripken committed Jul 12, 2024
1 parent 20c10df commit d2a48af
Show file tree
Hide file tree
Showing 6 changed files with 1,418 additions and 36 deletions.
1 change: 1 addition & 0 deletions src/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(ir_SOURCES
LocalGraph.cpp
LocalStructuralDominance.cpp
ReFinalize.cpp
return-utils.cpp
stack-utils.cpp
table-utils.cpp
type-updating.cpp
Expand Down
99 changes: 99 additions & 0 deletions src/ir/return-utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ir/return-utils.h"
#include "ir/module-utils.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm.h"

namespace wasm::ReturnUtils {

namespace {

// Walker that strips returned values out of a function body: the value of each
// explicit return is dropped right before the (now value-less) return, and a
// concrete value flowing out of the body is wrapped in a drop. Return calls are
// not handled; encountering one reports a Fatal error.
struct ReturnValueRemover : public PostWalker<ReturnValueRemover> {
  // Shared check for all call kinds: only return calls are a problem.
  template<typename T> void handleReturnCall(T* curr) {
    if (!curr->isReturn) {
      return;
    }
    Fatal() << "Cannot remove return_calls in ReturnValueRemover";
  }

  void visitCall(Call* curr) { handleReturnCall(curr); }
  void visitCallIndirect(CallIndirect* curr) { handleReturnCall(curr); }
  void visitCallRef(CallRef* curr) { handleReturnCall(curr); }

  void visitReturn(Return* curr) {
    // Every return we see is expected to carry a value.
    assert(curr->value);
    Builder builder(*getModule());
    // Keep the value's computation by dropping it, then return nothing.
    auto* droppedValue = builder.makeDrop(curr->value);
    curr->value = nullptr;
    replaceCurrent(builder.makeSequence(droppedValue, curr));
  }

  void visitFunction(Function* curr) {
    // Nothing to do unless a concrete value flows out of the body.
    if (!curr->body->type.isConcrete()) {
      return;
    }
    Builder builder(*getModule());
    curr->body = builder.makeDrop(curr->body);
  }
};

} // anonymous namespace

// Removes returned values from |func| by running ReturnValueRemover over it,
// using |wasm| as the module the walk happens in.
void removeReturns(Function* func, Module& wasm) {
  ReturnValueRemover remover;
  remover.walkFunctionInModule(func, &wasm);
}

// Computes, for the functions in |wasm|, whether each one contains a return
// call (as a call, call_indirect, or call_ref). Imported functions are not
// scanned and keep the value the analysis initialized them with.
std::unordered_map<Function*, bool> findReturnCallers(Module& wasm) {
  // Scans an expression tree and notes whether any call in it is a return
  // call. (A local class cannot have member templates, so each call kind gets
  // its own visitor.)
  struct Finder : PostWalker<Finder> {
    bool hasReturnCall = false;

    void visitCall(Call* curr) {
      hasReturnCall = hasReturnCall || curr->isReturn;
    }
    void visitCallIndirect(CallIndirect* curr) {
      hasReturnCall = hasReturnCall || curr->isReturn;
    }
    void visitCallRef(CallRef* curr) {
      hasReturnCall = hasReturnCall || curr->isReturn;
    }
  };

  ModuleUtils::ParallelFunctionAnalysis<bool> analysis(
    wasm, [&](Function* func, bool& hasReturnCall) {
      // Imports have no body to scan.
      if (func->imported()) {
        return;
      }

      Finder finder;
      finder.walk(func->body);
      hasReturnCall = finder.hasReturnCall;
    });

  // Convert to an unordered map for fast lookups. TODO: Avoid a copy here.
  std::unordered_map<Function*, bool> ret;
  ret.reserve(analysis.map.size());
  for (auto& [func, doesReturnCall] : analysis.map) {
    ret.emplace(func, doesReturnCall);
  }
  return ret;
}

} // namespace wasm::ReturnUtils
39 changes: 39 additions & 0 deletions src/ir/return-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef wasm_ir_return_h
#define wasm_ir_return_h

#include "wasm.h"

// Utilities for working with the values that functions return.
namespace wasm::ReturnUtils {

// Removes values from both explicit returns and implicit ones (values that flow
// from the body). This is useful after changing a function's type to no longer
// return anything.
//
// The value's computation is kept: it is dropped rather than deleted, so only
// the act of returning it goes away.
//
// This does *not* handle return calls, and will error on them. Removing a
// return call may change the semantics of the program, so we do not do it
// automatically here. Callers can use findReturnCallers() below to avoid
// calling this on such functions.
void removeReturns(Function* func, Module& wasm);

// Return a map of every function to whether it does a return call.
//
// NOTE(review): std::unordered_map is used here without a direct include of
// <unordered_map>; presumably it arrives transitively via wasm.h - confirm.
using ReturnCallersMap = std::unordered_map<Function*, bool>;
ReturnCallersMap findReturnCallers(Module& wasm);

} // namespace wasm::ReturnUtils

#endif // wasm_ir_return_h
19 changes: 2 additions & 17 deletions src/passes/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "ir/find_all.h"
#include "ir/lubs.h"
#include "ir/module-utils.h"
#include "ir/return-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "param-utils.h"
Expand Down Expand Up @@ -358,23 +359,7 @@ struct DAE : public Pass {
}
}
// Remove any return values.
struct ReturnUpdater : public PostWalker<ReturnUpdater> {
Module* module;
ReturnUpdater(Function* func, Module* module) : module(module) {
walk(func->body);
}
void visitReturn(Return* curr) {
auto* value = curr->value;
assert(value);
curr->value = nullptr;
Builder builder(*module);
replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr));
}
} returnUpdater(func, module);
// Remove any value flowing out.
if (func->body->type.isConcrete()) {
func->body = Builder(*module).makeDrop(func->body);
}
ReturnUtils::removeReturns(func, *module);
}

// Given a function and all the calls to it, see if we can refine the type of
Expand Down
105 changes: 86 additions & 19 deletions src/passes/Monomorphize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/return-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
Expand All @@ -103,6 +104,36 @@ namespace wasm {

namespace {

// Core information about a call: the call itself, and if it is dropped, the
// drop.
//
// This is a plain aggregate that is brace-initialized in field order (call,
// then drop), so the order of the fields must not change.
struct CallInfo {
// The call expression itself.
Call* call;
// If the call is dropped, this is the location of the pointer that refers to
// the drop. We keep that location rather than the drop itself so that we can
// overwrite it: when we optimize a dropped call we need to replace
// (drop (call)) with (call). If the call is not dropped, this is nullptr.
Expression** drop;
};

// Collects every call in the walked code into |infos|, noting for each one
// whether it is the direct child of a drop (and if so, where that drop lives).
struct CallFinder : public PostWalker<CallFinder> {
  std::vector<CallInfo> infos;

  void visitCall(Call* curr) {
    // Record the call as not dropped for now. If it turns out to be dropped,
    // visitDrop fills in the drop afterwards.
    infos.push_back({curr, nullptr});
  }

  void visitDrop(Drop* curr) {
    // Only drops whose value is directly a call are of interest.
    if (!curr->value->is<Call>()) {
      return;
    }

    // The dropped call must be the entry we most recently added to |infos|;
    // attach this drop's location to it.
    assert(!infos.empty());
    CallInfo& last = infos.back();
    assert(last.call == curr->value);
    last.drop = getCurrentPointer();
  }
};

// Relevant information about a callsite for purposes of monomorphization.
struct CallContext {
// The operands of the call, processed to leave the parts that make sense to
Expand Down Expand Up @@ -181,12 +212,12 @@ struct CallContext {
// remaining values by updating |newOperands| (for example, if all the values
// sent are constants, then |newOperands| will end up empty, as we have
// nothing left to send).
void buildFromCall(Call* call,
void buildFromCall(CallInfo& info,
std::vector<Expression*>& newOperands,
Module& wasm) {
Builder builder(wasm);

for (auto* operand : call->operands) {
for (auto* operand : info.call->operands) {
// Process the operand. This is a copy operation, as we are trying to move
// (copy) code from the callsite into the called function. When we find we
// can copy then we do so, and when we cannot that value remains as a
Expand All @@ -212,8 +243,7 @@ struct CallContext {
}));
}

// TODO: handle drop
dropped = false;
dropped = !!info.drop;
}

// Checks whether an expression can be moved into the context.
Expand Down Expand Up @@ -299,6 +329,11 @@ struct Monomorphize : public Pass {
void run(Module* module) override {
// TODO: parallelize, see comments below

// Find all the return-calling functions. We cannot remove their returns
// (because turning a return call into a normal call may break the program
// by using more stack).
auto returnCallersMap = ReturnUtils::findReturnCallers(*module);

// Note the list of all functions. We'll be adding more, and do not want to
// operate on those.
std::vector<Name> funcNames;
Expand All @@ -309,26 +344,38 @@ struct Monomorphize : public Pass {
// to call the monomorphized targets.
for (auto name : funcNames) {
auto* func = module->getFunction(name);
for (auto* call : FindAll<Call>(func->body).list) {
if (call->type == Type::unreachable) {

CallFinder callFinder;
callFinder.walk(func->body);
for (auto& info : callFinder.infos) {
if (info.call->type == Type::unreachable) {
// Ignore unreachable code.
// TODO: return_call?
continue;
}

if (call->target == name) {
if (info.call->target == name) {
// Avoid recursion, which adds some complexity (as we'd be modifying
// ourselves if we apply optimizations).
continue;
}

processCall(call, *module);
// If the target function does a return call, then as noted earlier we
// cannot remove its returns, so do not consider the drop as part of the
// context in such cases (as if we reverse-inlined the drop into the
// target then we'd be removing the returns).
if (returnCallersMap[module->getFunction(info.call->target)]) {
info.drop = nullptr;
}

processCall(info, *module);
}
}
}

// Try to optimize a call.
void processCall(Call* call, Module& wasm) {
void processCall(CallInfo& info, Module& wasm) {
auto* call = info.call;
auto target = call->target;
auto* func = wasm.getFunction(target);
if (func->imported()) {
Expand All @@ -342,19 +389,16 @@ struct Monomorphize : public Pass {
// if we use that context.
CallContext context;
std::vector<Expression*> newOperands;
context.buildFromCall(call, newOperands, wasm);
context.buildFromCall(info, newOperands, wasm);

// See if we've already evaluated this function + call context. If so, then
// we've memoized the result.
auto iter = funcContextMap.find({target, context});
if (iter != funcContextMap.end()) {
auto newTarget = iter->second;
if (newTarget != target) {
// When we computed this before we found a benefit to optimizing, and
// created a new monomorphized function to call. Use it by simply
// applying the new operands we computed, and adjusting the call target.
call->operands.set(newOperands);
call->target = newTarget;
// We saw benefit to optimizing this case. Apply that.
updateCall(info, newTarget, newOperands, wasm);
}
return;
}
Expand Down Expand Up @@ -419,8 +463,7 @@ struct Monomorphize : public Pass {
if (worthwhile) {
// We are using the monomorphized function, so update the call and add it
// to the module.
call->operands.set(newOperands);
call->target = monoFunc->name;
updateCall(info, monoFunc->name, newOperands, wasm);

wasm.addFunction(std::move(monoFunc));
}
Expand Down Expand Up @@ -453,8 +496,9 @@ struct Monomorphize : public Pass {
newParams.push_back(operand->type);
}
}
// TODO: support changes to results.
auto newResults = func->getResults();
// If we were dropped then we are pulling the drop into the monomorphized
// function, which means we return nothing.
auto newResults = context.dropped ? Type::none : func->getResults();
newFunc->type = Signature(Type(newParams), newResults);

// We must update local indexes: the new function has a potentially
Expand Down Expand Up @@ -549,9 +593,32 @@ struct Monomorphize : public Pass {
newFunc->body = builder.makeBlock(pre);
}

if (context.dropped) {
ReturnUtils::removeReturns(newFunc.get(), wasm);
}

return newFunc;
}

// Redirects the call described by |info| to |newTarget| and installs
// |newOperands| as its operands. If the call was dropped, the drop is removed
// as well: the drop's slot is overwritten with the call, which from now on has
// type none.
void updateCall(const CallInfo& info,
                Name newTarget,
                const std::vector<Expression*>& newOperands,
                Module& wasm) {
  auto* call = info.call;
  call->operands.set(newOperands);
  call->target = newTarget;

  if (!info.drop) {
    // Not dropped, so the call keeps its type and position.
    return;
  }

  // Replace (drop (call)) with (call), that is, replace the drop with the
  // (updated) call which now has type none. Note we should have handled
  // unreachability before getting here.
  assert(call->type != Type::unreachable);
  call->type = Type::none;
  *info.drop = call;
}

// Run some function-level optimizations on a function. Ideally we would run a
// minimal amount of optimizations here, but we do want to give the optimizer
// as much of a chance to work as possible, so for now do all of -O3 (in
Expand Down
Loading

0 comments on commit d2a48af

Please sign in to comment.