Skip to content

Commit

Permalink
[SandboxVec] Add barebones Region class. (llvm#108899)
Browse files Browse the repository at this point in the history
A region identifies a set of vector instructions generated by
vectorization passes. The vectorizer can then run a series of
RegionPasses on the region, evaluate the cost, and commit/reject the
transforms on a region-by-region basis, instead of an entire basic
block.

This is heavily based ov @vporpo's prototype. In particular, the doc
comment for the Region class is all his. The rest of this commit is
mostly boilerplate around a SetVector: getters, iterators, and some
debug helpers.
  • Loading branch information
slackito committed Sep 17, 2024
1 parent 790f2eb commit 3aecf41
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 0 deletions.
106 changes: 106 additions & 0 deletions llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Region.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
//===- Region.h -------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_REGION_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_REGION_H

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/raw_ostream.h"

namespace llvm::sandboxir {

/// The main job of the Region is to point to new instructions generated by
/// vectorization passes. It is the unit that RegionPasses operate on with their
/// runOnRegion() function.
///
/// The region allows us to stack transformations horizontally, meaning that
/// each transformation operates on a single region and the resulting region is
/// the input to the next transformation, as opposed to vertically, which is the
/// common way of applying a transformation across the whole BB. This enables us
/// to check for profitability and decide whether we accept or rollback at a
/// region granularity, which is much better than doing this at the BB level.
///
// Traditional approach: transformations applied vertically for the whole BB
// BB
// +----+
// | |
// | |
// | | -> Transform1 -> ... -> TransformN -> Check Cost
// | |
// | |
// +----+
//
// Region-based approach: transformations applied horizontally, for each Region
// BB
// +----+
// |Rgn1| -> Transform1 -> ... -> TransformN -> Check Cost
// | |
// |Rgn2| -> Transform1 -> ... -> TransformN -> Check Cost
// | |
// |Rgn3| -> Transform1 -> ... -> TransformN -> Check Cost
// +----+

class Region {
/// All the instructions in the Region. Only new instructions generated during
/// vectorization are part of the Region.
SetVector<Instruction *> Insts;

/// A unique ID, used for debugging.
unsigned RegionID = 0;

Context &Ctx;

/// The basic block containing this region.
BasicBlock &BB;

// TODO: Add cost modeling.
// TODO: Add a way to encode/decode region info to/from metadata.

public:
Region(Context &Ctx, BasicBlock &BB);
~Region();

BasicBlock *getParent() const { return &BB; }
Context &getContext() const { return Ctx; }
/// Returns the region's unique ID.
unsigned getID() const { return RegionID; }

/// Adds I to the set.
void add(Instruction *I);
/// Removes I from the set.
void remove(Instruction *I);
/// Returns true if I is in the Region.
bool contains(Instruction *I) const { return Insts.contains(I); }
/// Returns true if the Region has no instructions.
bool empty() const { return Insts.empty(); }

using iterator = decltype(Insts.begin());
iterator begin() { return Insts.begin(); }
iterator end() { return Insts.end(); }
iterator_range<iterator> insts() { return make_range(begin(), end()); }

#ifndef NDEBUG
/// This is an expensive check, meant for testing.
bool operator==(const Region &Other) const;
bool operator!=(const Region &other) const { return !(*this == other); }

void dump(raw_ostream &OS) const;
void dump() const;
friend raw_ostream &operator<<(raw_ostream &OS, const Region &Rgn) {
Rgn.dump(OS);
return OS;
}
#endif
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_REGION_H
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_llvm_component_library(LLVMVectorize
LoopVectorize.cpp
SandboxVectorizer/DependencyGraph.cpp
SandboxVectorizer/Passes/BottomUpVec.cpp
SandboxVectorizer/Region.cpp
SandboxVectorizer/SandboxVectorizer.cpp
SLPVectorizer.cpp
Vectorize.cpp
Expand Down
46 changes: 46 additions & 0 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Region.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===- Region.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Region.h"

namespace llvm::sandboxir {

Region::Region(Context &Ctx, BasicBlock &BB) : Ctx(Ctx), BB(BB) {
static unsigned StaticRegionID;
RegionID = StaticRegionID++;
}

Region::~Region() {}

void Region::add(Instruction *I) { Insts.insert(I); }

void Region::remove(Instruction *I) { Insts.remove(I); }

#ifndef NDEBUG
bool Region::operator==(const Region &Other) const {
if (Insts.size() != Other.Insts.size())
return false;
if (!std::is_permutation(Insts.begin(), Insts.end(), Other.Insts.begin()))
return false;
return true;
}

void Region::dump(raw_ostream &OS) const {
OS << "RegionID: " << getID() << "\n";
for (auto *I : Insts)
OS << *I << "\n";
}

void Region::dump() const {
dump(dbgs());
dbgs() << "\n";
}

} // namespace llvm::sandboxir

#endif // NDEBUG
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ set(LLVM_LINK_COMPONENTS

add_llvm_unittest(SandboxVectorizerTests
DependencyGraphTest.cpp
RegionTest.cpp
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===- RegionTest.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Region.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

using namespace llvm;

struct RegionTest : public testing::Test {
LLVMContext C;
std::unique_ptr<Module> M;

void parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
M = parseAssemblyString(IR, Err, C);
if (!M)
Err.print("RegionTest", errs());
}
};

TEST_F(RegionTest, Basic) {
parseIR(C, R"IR(
define i8 @foo(i8 %v0, i8 %v1) {
%t0 = add i8 %v0, 1
%t1 = add i8 %t0, %v1
ret i8 %t1
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *T0 = cast<sandboxir::Instruction>(&*It++);
auto *T1 = cast<sandboxir::Instruction>(&*It++);
auto *Ret = cast<sandboxir::Instruction>(&*It++);
sandboxir::Region Rgn(Ctx, *BB);

// Check getters
EXPECT_EQ(BB, Rgn.getParent());
EXPECT_EQ(&Ctx, &Rgn.getContext());
EXPECT_EQ(0U, Rgn.getID());

// Check add / remove / empty.
EXPECT_TRUE(Rgn.empty());
Rgn.add(T0);
EXPECT_FALSE(Rgn.empty());
Rgn.remove(T0);
EXPECT_TRUE(Rgn.empty());

// Check iteration.
Rgn.add(T0);
Rgn.add(T1);
Rgn.add(Ret);
// Use an ordered matcher because we're supposed to preserve the insertion
// order for determinism.
EXPECT_THAT(Rgn.insts(), testing::ElementsAre(T0, T1, Ret));

// Check contains
EXPECT_TRUE(Rgn.contains(T0));
Rgn.remove(T0);
EXPECT_FALSE(Rgn.contains(T0));

#ifndef NDEBUG
// Check equality comparison. Insert in reverse order into `Other` to check
// that comparison is order-independent.
sandboxir::Region Other(Ctx, *BB);
Other.add(Ret);
EXPECT_NE(Rgn, Other);
Other.add(T1);
EXPECT_EQ(Rgn, Other);
#endif
}

0 comments on commit 3aecf41

Please sign in to comment.