Skip to content

Commit

Permalink
intermediate check in
Browse files Browse the repository at this point in the history
  • Loading branch information
lialan committed Dec 19, 2024
1 parent 3099d0b commit d80ef3a
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 2 deletions.
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,9 @@ RankedTensorType dropEncoding(RankedTensorType type) {
return RankedTensorType::get(type.getShape(), type.getElementType());
}

RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type) {
if (!IREE::Encoding::hasPackedStorageAttr(type)) return type;
return RankedTensorType::get(type.getShape(), type.getElementType());
}

} // namespace mlir::iree_compiler
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -79,6 +80,9 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Returns the RankedTensorType without encodings.
RankedTensorType dropEncoding(RankedTensorType type);

/// Returns the RankedTensorType without packed storage encoding (if any).
RankedTensorType dropPackedStorageEncodingIfAny(RankedTensorType type);

/// Utility method to convert from `set_encoding` op to `pack` operation.
/// NOTE: `source` could be returned when packing is not needed.
FailureOr<Value> lowerSetEncodingOpToPackOp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.h"

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/Analysis/Captures.h"
#include "iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
Expand Down Expand Up @@ -478,7 +480,8 @@ struct TensorExportBufferViewOpPattern
}

auto loc = exportOp.getLoc();
auto tensorType = llvm::cast<RankedTensorType>(adaptor.getSourceEncoding());
auto tensorType = dropPackedStorageEncodingIfAny(
llvm::cast<RankedTensorType>(adaptor.getSourceEncoding()));
auto dynamicDims = adaptor.getSourceEncodingDims();

// NOTE: we should have verified supported encodings/types at entry into the
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Util/IR/ClosureOpUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
Expand All @@ -27,6 +28,10 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler {
using IREE::Encoding::getEncodingAttr;
}

namespace mlir::iree_compiler::IREE::Stream {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1512,7 +1517,7 @@ LogicalResult TensorCloneOp::verify() {
// information.
auto sourceEncoding = llvm::cast<RankedTensorType>(op.getSourceEncoding());
auto resultEncoding = llvm::cast<RankedTensorType>(op.getResultEncoding());
if (sourceEncoding.getEncoding() != resultEncoding.getEncoding()) {
if (getEncodingAttr(sourceEncoding) != getEncodingAttr(resultEncoding)) {
return op.emitOpError() << "clones changing tensor encoding from "
<< sourceEncoding.getEncoding() << " to "
<< resultEncoding.getEncoding() << "; not allowed";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
Expand All @@ -22,6 +23,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -247,6 +249,12 @@ struct ConvertToStreamPass final
if (llvm::isa<IREE::Flow::ChannelType>(type)) {
return IREE::Stream::ChannelType::get(context);
}
if (auto rankedType = llvm::dyn_cast_or_null<RankedTensorType>(type)) {
if (IREE::Encoding::hasPackedStorageAttr(rankedType)) {
return RankedTensorType::get(rankedType.getShape(),
rankedType.getElementType());
}
}
return !llvm::isa<TensorType>(type) ? type : Type{};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@ func.func @aligned_i1_size() -> index {
// CHECK: func @aligned_i1_size() -> index {
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index

// -----

#packed = #iree_encoding.packed_storage
func.func @packed_i1_input_output(%input : tensor<16xi1, #packed>) -> tensor<16xi1, #packed> {
return %input : tensor<16xi1, #packed>
}
15 changes: 15 additions & 0 deletions tests/e2e/subbyte_types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,24 @@ iree_check_single_backend_test_suite(
target_backend = "llvm-cpu",
)

iree_check_single_backend_test_suite(
name = "check_llvm-cpu_subbyte_emulation_attr",
srcs = ["subbyte_types_attr.mlir"],
compiler_flags = [
"--iree-llvmcpu-target-cpu=generic",
],
driver = "local-task",
tags = [
# subbyte support for wasm is not on priorities.
"nowasm",
],
target_backend = "llvm-cpu",
)

test_suite(
name = "check",
tests = [
":check_llvm-cpu_subbyte_emulation",
":check_llvm-cpu_subbyte_emulation_attr",
],
)
15 changes: 15 additions & 0 deletions tests/e2e/subbyte_types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,19 @@ iree_check_single_backend_test_suite(
"nowasm"
)

iree_check_single_backend_test_suite(
NAME
check_llvm-cpu_subbyte_emulation_attr
SRCS
"subbyte_types_attr.mlir"
TARGET_BACKEND
"llvm-cpu"
DRIVER
"local-task"
COMPILER_FLAGS
"--iree-llvmcpu-target-cpu=generic"
LABELS
"nowasm"
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###

0 comments on commit d80ef3a

Please sign in to comment.