feat(torch): add JIT FusionStrategy selection
Bycob authored and beniz committed Nov 8, 2024
1 parent 3afac74 commit d2331be
Showing 2 changed files with 55 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/backends/torch/torchlib.cc
@@ -30,6 +30,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <fcntl.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include "native/native.h"
#include "torchsolver.h"
@@ -188,9 +189,20 @@ namespace dd
void TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::init_mllib(const APIData &lib_ad)
{
// Get parameters
auto mllib_dto = lib_ad.createSharedDTO<DTO::MLLib>();

// experimental: jit compiler parameters
if (mllib_dto->jit_compiler_params)
{
auto jit_params = mllib_dto->jit_compiler_params;
torch::jit::FusionStrategy strat = {
{ torch::jit::FusionBehavior::STATIC, jit_params->fusion_static },
{ torch::jit::FusionBehavior::DYNAMIC, jit_params->fusion_dynamic }
};
auto old_strat = torch::jit::setFusionStrategy(strat);
}

// publish model (from existing repository)
if (mllib_dto->from_repository != nullptr)
{
this->_mlmodel.copy_to_target(mllib_dto->from_repository,
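For reference, the new init_mllib branch above simply forwards the two DTO values to libtorch's setFusionStrategy. Below is a minimal standalone sketch of that call; the function name configure_jit_fusion and the literal counts are illustrative only, not part of the commit:

    #include <torch/csrc/jit/runtime/graph_executor.h>

    void configure_jit_fusion()
    {
      // One (behavior, max number of specializations) pair per fusion mode.
      torch::jit::FusionStrategy strat = {
        { torch::jit::FusionBehavior::STATIC, 1 },  // e.g. fusion_static = 1
        { torch::jit::FusionBehavior::DYNAMIC, 4 }  // e.g. fusion_dynamic = 4
      };

      // setFusionStrategy is process-wide and returns the strategy previously
      // in effect, which the commit captures as old_strat but does not reuse.
      auto old_strat = torch::jit::setFusionStrategy(strat);
      (void)old_strat;
    }

In the commit, the two counts come from jit_params->fusion_static and jit_params->fusion_dynamic rather than the literals used here.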
42 changes: 42 additions & 0 deletions src/dto/mllib.hpp
@@ -45,6 +45,38 @@ namespace dd
DTO_FIELD(Int32, test_batch_size) = 1;
};

// Torch JIT compiler
/** These parameters are set using the setFusionStrategy function:
* https://github.com/pytorch/pytorch/blob/4c074a9b8bd2e6d8940b40a41ce399e6c4a463a9/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp#L131
*/
class JitCompilerParameters : public oatpp::DTO
{
DTO_INIT(JitCompilerParameters, DTO)

DTO_FIELD_INFO(fusion_static)
{
info->description
= "Set number of STATIC specializations that can occur when "
"compiling jit models. In STATIC fusion, fused ops are compiled "
"to have fixed shapes. See "
"https://pytorch.org/docs/stable/generated/"
"torch.jit.set_fusion_strategy.html for more details.";
}
DTO_FIELD(Int32, fusion_static) = 2;

DTO_FIELD_INFO(fusion_dynamic)
{
info->description
= "Set number of DYNAMIC specializations that can occur when "
"compiling jit models. In DYNAMIC fusion, fused ops are "
"compiled to have variable input shape, so that this "
"specialization can be used with multiple shapes. See "
"https://pytorch.org/docs/stable/generated/"
"torch.jit.set_fusion_strategy.html for more details";
}
DTO_FIELD(Int32, fusion_dynamic) = 10;
};

class MLLib : public oatpp::DTO
{
DTO_INIT(MLLib, DTO /* extends */)
@@ -191,6 +223,16 @@
}
DTO_FIELD(Boolean, concurrent_predict) = true;

DTO_FIELD_INFO(jit_compiler_params)
{
info->description
= "[experimental] Modify torch jit compiler parameters. This can "
"be useful when the model takes too long to compile, for "
"example. Beware, this parameter applies instantly for all dede "
"services.";
}
DTO_FIELD(Object<JitCompilerParameters>, jit_compiler_params);

// Libtorch predict options
DTO_FIELD_INFO(forward_method)
{
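Assuming the dd::DTO namespace referenced in torchlib.cc and oatpp's usual createShared() factory, a caller-side sketch of populating the new fields might look like this; values and include path are hypothetical, and in practice the fields arrive through the mllib parameters of a service creation call handled by init_mllib:

    #include "dto/mllib.hpp"  // include path assumed from the file shown above

    auto jit_params = dd::DTO::JitCompilerParameters::createShared();
    jit_params->fusion_static = 1;   // at most one fixed-shape specialization
    jit_params->fusion_dynamic = 4;  // up to four variable-shape specializations

    auto mllib_dto = dd::DTO::MLLib::createShared();
    mllib_dto->jit_compiler_params = jit_params;

As the field description notes, the resulting setFusionStrategy call takes effect process-wide, so it affects every service running in the same dede instance.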
