Skip to content

Commit

Permalink
PR #19116: [XLA:CPU] [oneDNN] Refactoring oneDNN Memory Util for Cust…
Browse files Browse the repository at this point in the history
…om Call oneDNN Thunk Runtime Support

Imported from GitHub PR #19116

At thunk execution, the memory buffer info for oneDNN is created based on the shapes of the input arguments and the output results. This PR refactors the `onednn_memory_util` to create memory references from shape, which will be used in a separate PR to add custom call oneDNN thunk support.
Copybara import of the project:

--
8664ea2 by Om Thakkar <om.thakkar@intel.com>:

onednn_memory_util refactoring for thunk support

Merging this change closes #19116

COPYBARA_INTEGRATE_REVIEW=#19116 from Intel-tensorflow:othakkar/memref_onednn_mem_util 3389b1a
PiperOrigin-RevId: 697001291
  • Loading branch information
othakkar authored and Google-ML-Automation committed Nov 15, 2024
1 parent e58b14e commit 4d5f691
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 8 additions & 6 deletions xla/service/cpu/onednn_memory_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ struct MemrefInfoPOD {
void* data;
};

MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
MemrefInfoHandler CreateMemrefFromShape(const Shape& shape, void* const buf) {
MemrefInfoHandler result(new MemrefInfoPOD);

const auto& shape = literal->shape();
result->dtype = shape.element_type();
result->rank = shape.rank();
auto dimensions = shape.dimensions();
Expand All @@ -65,12 +63,16 @@ MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
result->strides[i] = stride;
stride *= dimensions.at(i);
}

result->data = const_cast<void*>(literal->untyped_data());

result->data = buf;
return result;
}

MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) {
const auto& shape = literal->shape();
void* const buf = const_cast<void*>(literal->untyped_data());
return CreateMemrefFromShape(shape, buf);
}

StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
const llvm_ir::IrArray& ir_array) {
const Shape& shape = ir_array.GetShape();
Expand Down
4 changes: 3 additions & 1 deletion xla/service/cpu/onednn_memory_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ using MemrefInfoHandler = std::shared_ptr<MemrefInfoPOD>;

MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal);

MemrefInfoHandler CreateMemrefFromShape(const Shape& shape, void* buf);

StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilderBase& builder,
const llvm_ir::IrArray& ir_array);

Expand Down Expand Up @@ -102,7 +104,7 @@ inline PrimitiveType ToXlaPrimitiveType(dnnl::memory::data_type dtype) {

class MemrefInfo {
public:
MemrefInfo(void* data);
explicit MemrefInfo(void* pod_data);

dnnl::memory::dims GetOneDnnDims() const;
dnnl::memory::dims GetOneDnnStrides() const;
Expand Down

0 comments on commit 4d5f691

Please sign in to comment.