feat(torch): add JIT FusionStrategy selection #1561

Merged 1 commit on Nov 8, 2024
14 changes: 13 additions & 1 deletion src/backends/torch/torchlib.cc
@@ -30,6 +30,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <fcntl.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include "native/native.h"
#include "torchsolver.h"
@@ -188,9 +189,20 @@ namespace dd
void TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::init_mllib(const APIData &lib_ad)
{
// Get parameters
auto mllib_dto = lib_ad.createSharedDTO<DTO::MLLib>();

// experimental: jit compiler parameters
if (mllib_dto->jit_compiler_params)
{
auto jit_params = mllib_dto->jit_compiler_params;
torch::jit::FusionStrategy strat = {
{ torch::jit::FusionBehavior::STATIC, jit_params->fusion_static },
{ torch::jit::FusionBehavior::DYNAMIC, jit_params->fusion_dynamic }
};
auto old_strat = torch::jit::setFusionStrategy(strat);
}

// publish model (from existing repository)
if (mllib_dto->from_repository != nullptr)
{
this->_mlmodel.copy_to_target(mllib_dto->from_repository,
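For reference, the snippet below is a minimal standalone sketch of the libtorch call wired in above, using the same defaults the DTO introduces (2 STATIC and 10 DYNAMIC specializations). It assumes a program linked against libtorch; the dede service plumbing from the diff is left out.

```cpp
// Minimal sketch (not dede code): install a JIT FusionStrategy globally,
// mirroring what init_mllib() does with the new jit_compiler_params DTO.
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <iostream>

int main()
{
  // A FusionStrategy is a list of (behavior, max specializations) pairs.
  // The values below match the DTO defaults: 2 STATIC, 10 DYNAMIC.
  torch::jit::FusionStrategy strat
      = { { torch::jit::FusionBehavior::STATIC, 2 },
          { torch::jit::FusionBehavior::DYNAMIC, 10 } };

  // setFusionStrategy() is process-wide and returns the previous strategy,
  // which is why the new parameter affects all services at once.
  auto old_strat = torch::jit::setFusionStrategy(strat);

  std::cout << "previous fusion strategy had " << old_strat.size()
            << " entries" << std::endl;
  return 0;
}
```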
42 changes: 42 additions & 0 deletions src/dto/mllib.hpp
@@ -45,6 +45,38 @@ namespace dd
DTO_FIELD(Int32, test_batch_size) = 1;
};

// Torch JIT compiler
/** These parameters are set using the function setFusionStrategy:
* https://github.com/pytorch/pytorch/blob/4c074a9b8bd2e6d8940b40a41ce399e6c4a463a9/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp#L131
*/
class JitCompilerParameters : public oatpp::DTO
{
DTO_INIT(JitCompilerParameters, DTO)

DTO_FIELD_INFO(fusion_static)
{
info->description
= "Set number of STATIC specializations that can occur when "
"compiling jit models. In STATIC fusion, fused ops are compiled "
"to have fixed shapes. See "
"https://pytorch.org/docs/stable/generated/"
"torch.jit.set_fusion_strategy.html for more details.";
}
DTO_FIELD(Int32, fusion_static) = 2;

DTO_FIELD_INFO(fusion_dynamic)
{
info->description
= "Set number of DYNAMIC specializations that can occur when "
"compiling jit models. In DYNAMIC fusion, fused ops are "
"compiled to have variable input shape, so that this "
"specialization can be used with multiple shapes. See "
"https://pytorch.org/docs/stable/generated/"
"torch.jit.set_fusion_strategy.html for more details";
}
DTO_FIELD(Int32, fusion_dynamic) = 10;
};

class MLLib : public oatpp::DTO
{
DTO_INIT(MLLib, DTO /* extends */)
@@ -191,6 +223,16 @@
}
DTO_FIELD(Boolean, concurrent_predict) = true;

DTO_FIELD_INFO(jit_compiler_params)
{
info->description
= "[experimental] Modify torch jit compiler parameters. This can "
"be useful when the model takes too long to compile, for "
"example. Beware, this parameter applies instantly for all dede "
"services.";
}
DTO_FIELD(Object<JitCompilerParameters>, jit_compiler_params);

// Libtorch predict options
DTO_FIELD_INFO(forward_method)
{
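For context on how the new field is exercised: init_mllib() reads it from the service's mllib parameters, so a client would supply it under `parameters.mllib` of a dede service creation request. A hypothetical fragment is shown below; the field names come from the DTO above, while the surrounding request structure is assumed and not part of this PR.

```json
{
  "parameters": {
    "mllib": {
      "jit_compiler_params": {
        "fusion_static": 2,
        "fusion_dynamic": 10
      }
    }
  }
}
```

When jit_compiler_params is omitted, setFusionStrategy is not called and libtorch's own defaults remain in effect.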