Merge pull request #149 from NotCompsky/main
Continue training from a PLY file
pierotofy authored Jan 17, 2025
2 parents 2d17327 + 2d48f84 commit dfa162a
Showing 5 changed files with 221 additions and 36 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -233,6 +233,14 @@ To generate compressed splats (.splat files), use the `-o` option:
./opensplat /path/to/banana -o banana.splat
```

### Resume

You can resume training from a .PLY file using the `--resume` option:

```bash
./opensplat /path/to/banana --resume ./splat.ply
```
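
If you save intermediate checkpoints with `-s`/`--save-every`, you can pick training back up from one of them later. Note that the .ply must have been written by OpenSplat, since the training step to resume from is read from the file's header comment. A sketch with illustrative step counts, assuming the default `splat.ply` output name:

```bash
# Write a checkpoint every 500 steps (splat_500.ply, splat_1000.ply, ...)
./opensplat /path/to/banana -s 500

# Later, continue training from the checkpoint written at step 2000
./opensplat /path/to/banana --resume ./splat_2000.ply
```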

### AMD GPU Notes

To train a model with AMD GPU using docker container, you can use the following command as a reference:
188 changes: 182 additions & 6 deletions model.cpp
@@ -6,6 +6,7 @@
#include "rasterize_gaussians.hpp"
#include "tensor_math.hpp"
#include "gsplat.hpp"
#include "utils.hpp"

#ifdef USE_HIP
#include <c10/hip/HIPCachingAllocator.h>
@@ -50,6 +51,30 @@ torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){
return torch::abs(gt - rendered).mean();
}

void Model::setupOptimizers(){
releaseOptimizers();

meansOpt = new torch::optim::Adam({means}, torch::optim::AdamOptions(0.00016));
scalesOpt = new torch::optim::Adam({scales}, torch::optim::AdamOptions(0.005));
quatsOpt = new torch::optim::Adam({quats}, torch::optim::AdamOptions(0.001));
featuresDcOpt = new torch::optim::Adam({featuresDc}, torch::optim::AdamOptions(0.0025));
featuresRestOpt = new torch::optim::Adam({featuresRest}, torch::optim::AdamOptions(0.000125));
opacitiesOpt = new torch::optim::Adam({opacities}, torch::optim::AdamOptions(0.05));

meansOptScheduler = new OptimScheduler(meansOpt, 0.0000016f, maxSteps);
}

void Model::releaseOptimizers(){
RELEASE_SAFELY(meansOpt);
RELEASE_SAFELY(scalesOpt);
RELEASE_SAFELY(quatsOpt);
RELEASE_SAFELY(featuresDcOpt);
RELEASE_SAFELY(featuresRestOpt);
RELEASE_SAFELY(opacitiesOpt);

RELEASE_SAFELY(meansOptScheduler);
}


torch::Tensor Model::forward(Camera& cam, int step){

@@ -460,22 +485,22 @@ void Model::afterTrain(int step){
}
}

void Model::save(const std::string &filename){
void Model::save(const std::string &filename, int step){
if (fs::path(filename).extension().string() == ".splat"){
saveSplat(filename);
}else{
savePly(filename);
savePly(filename, step);
}
std::cout << "Wrote " << filename << std::endl;
}

void Model::savePly(const std::string &filename){
void Model::savePly(const std::string &filename, int step){
std::ofstream o(filename, std::ios::binary);
int numPoints = means.size(0);

o << "ply" << std::endl;
o << "format binary_little_endian 1.0" << std::endl;
o << "comment Generated by opensplat" << std::endl;
o << "comment Generated by opensplat at iteration " << step << std::endl;
o << "element vertex " << numPoints << std::endl;
o << "property float x" << std::endl;
o << "property float y" << std::endl;
@@ -564,14 +589,14 @@ void Model::saveSplat(const std::string &filename){
o.close();
}

void Model::saveDebugPly(const std::string &filename){
void Model::saveDebugPly(const std::string &filename, int step){
// A standard PLY
std::ofstream o(filename, std::ios::binary);
int numPoints = means.size(0);

o << "ply" << std::endl;
o << "format binary_little_endian 1.0" << std::endl;
o << "comment Generated by opensplat" << std::endl;
o << "comment Generated by opensplat at iteration " << step << std::endl;
o << "element vertex " << numPoints << std::endl;
o << "property float x" << std::endl;
o << "property float y" << std::endl;
@@ -593,6 +618,157 @@ void Model::saveDebugPly(const std::string &filename){
std::cout << "Wrote " << filename << std::endl;
}
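
// For reference, a sketch of the header layout that loadPly() (below) expects.
// The f_dc_*/f_rest_* property counts vary (loadPly counts them at run time),
// and the iteration and vertex counts shown here are illustrative only:
//
//   ply
//   format binary_little_endian 1.0
//   comment Generated by opensplat at iteration 2000
//   element vertex 12345
//   property float x
//   property float y
//   property float z
//   property float nx
//   property float ny
//   property float nz
//   property float f_dc_0
//   ...
//   property float f_rest_0
//   ...
//   property float opacity
//   property float scale_0
//   property float scale_1
//   property float scale_2
//   property float rot_0
//   property float rot_1
//   property float rot_2
//   property float rot_3
//   end_header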

int Model::loadPly(const std::string &filename){
std::ifstream f(filename, std::ios::binary);
if (!f.is_open()) throw std::runtime_error("Invalid PLY file");

// Ensure we have a valid ply file
std::string line;
int step;
size_t bytesRead = 0;

std::getline(f, line);
bytesRead += f.gcount();

if (line == "ply"){
std::getline(f, line);
bytesRead += f.gcount();
if (line == "format binary_little_endian 1.0"){
std::getline(f, line);
bytesRead += f.gcount();
const std::string pattern = "comment Generated by opensplat at iteration ";

if (line.rfind(pattern, 0) == 0){
step = std::stoi(line.substr(pattern.length()));
if (step >= 0){
std::getline(f, line);
bytesRead += f.gcount();
const std::string pattern = "element vertex ";

if (line.rfind(pattern, 0) == 0){
const int numPoints = std::stoi(line.substr(pattern.length()));

const char *requiredProps[] = {
"property float x",
"property float y",
"property float z",
"property float nx",
"property float ny",
"property float nz",
"property float f_dc_"
"property float f_rest_",
"property float opacity",
"property float scale_0",
"property float scale_1",
"property float scale_2",
"property float rot_0",
"property float rot_1",
"property float rot_2",
"property float rot_3",
"end_header"
};

for (int i = 0; i < 6; i++){
std::getline(f, line);
bytesRead += f.gcount();
if (line != requiredProps[i]){
throw std::runtime_error(std::string("PLY file's header does not contain required property: ") + requiredProps[i]);
}
}
std::getline(f, line);
bytesRead += f.gcount();

// Count consecutive header lines that start with the given prefix
// (e.g. "property float f_dc_0", "property float f_dc_1", ...)
auto countPrefixes = [&f, &line, &bytesRead](const char *prefix){
int n = 0;
while (line.rfind(prefix, 0) == 0){
++n;
std::getline(f, line);
bytesRead += f.gcount();
}
return n;
};
int featuresDcSize = countPrefixes("property float f_dc_");
int featuresRestSize = countPrefixes("property float f_rest_");

// countPrefixes() leaves the first non-matching header line in `line`,
// so validate the current line before reading the next one.
bool foundEnd = false;
for (size_t i = 8; i < std::size(requiredProps); i++){
if (line != requiredProps[i]){
throw std::runtime_error(std::string("PLY file's header does not contain required property: ") + requiredProps[i]);
}

if (line == "end_header"){
foundEnd = true;
break;
}

std::getline(f, line);
bytesRead += f.gcount();
}

if (!foundEnd){
throw std::runtime_error("PLY file header does not contain header end");
}

const size_t bytesPerPoint = sizeof(float) * (14 + featuresDcSize + featuresRestSize);
const size_t remainingFileSize = fs::file_size(filename) - bytesRead;
if (remainingFileSize == bytesPerPoint * numPoints){
std::cout << "Loading PLY..." << std::endl;

float zeros[3]; // scratch buffer used to read and discard the stored normals

torch::Tensor meansCpu = torch::zeros({numPoints, 3}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor featuresDcCpu = torch::zeros({numPoints, featuresDcSize}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor featuresRestCpu = torch::zeros({numPoints, featuresRestSize}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor opacitiesCpu = torch::zeros({numPoints, 1}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor scalesCpu = torch::zeros({numPoints, 3}, torch::TensorOptions().dtype(torch::kFloat32));
torch::Tensor quatsCpu = torch::zeros({numPoints, 4}, torch::TensorOptions().dtype(torch::kFloat32));

for (size_t i = 0; i < numPoints; i++){
f.read(reinterpret_cast<char *>(meansCpu[i].data_ptr()), sizeof(float) * 3);
f.read(reinterpret_cast<char *>(&zeros[0]), sizeof(float) * 3);
f.read(reinterpret_cast<char *>(featuresDcCpu[i].data_ptr()), sizeof(float) * featuresDcSize);
f.read(reinterpret_cast<char *>(featuresRestCpu[i].data_ptr()), sizeof(float) * featuresRestSize);
f.read(reinterpret_cast<char *>(opacitiesCpu[i].data_ptr()), sizeof(float) * 1);
f.read(reinterpret_cast<char *>(scalesCpu[i].data_ptr()), sizeof(float) * 3);
f.read(reinterpret_cast<char *>(quatsCpu[i].data_ptr()), sizeof(float) * 4);
}
if (keepCrs){
meansCpu = (meansCpu - translation) * scale;
scalesCpu = torch::log(scale * torch::exp(scalesCpu));
}

means = meansCpu.to(device).requires_grad_();
featuresDc = featuresDcCpu.to(device).requires_grad_();
featuresRest = featuresRestCpu.reshape({numPoints, 3, featuresRestSize/3}).transpose(2, 1).to(device).requires_grad_();
opacities = opacitiesCpu.to(device).requires_grad_();
scales = scalesCpu.to(device).requires_grad_();
quats = quatsCpu.to(device).requires_grad_();

std::cerr << "Loaded " << means.size(0) << " gaussians" << std::endl;

setupOptimizers();

f.close();
return step;
} else {
throw std::runtime_error("PLY file's data section is wrong size");
}
}
} else {
throw std::runtime_error("PLY file failed sanity check: iteration count should not begin at 0");
}
} else if (line.rfind("comment Generated by opensplat", 0) == 0){
throw std::runtime_error("PLY file does not contain iteration count metadata. You can add this metadata manually by changing \"comment Generated by opensplat\" to \"comment Generated by opensplat at iteration 12345\" (replace 12345 with the desired value).");
}
}
}
throw std::runtime_error("Invalid PLY file");
}

torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight){
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt);
torch::Tensor l1Loss = l1(rgb, gt);
42 changes: 16 additions & 26 deletions model.hpp
@@ -54,37 +54,27 @@ struct Model{
// backgroundColor = torch::tensor({0.0f, 0.0f, 0.0f}, device); // Black
backgroundColor = torch::tensor({0.6130f, 0.0101f, 0.3984f}, device); // Nerf Studio default

meansOpt = new torch::optim::Adam({means}, torch::optim::AdamOptions(0.00016));
scalesOpt = new torch::optim::Adam({scales}, torch::optim::AdamOptions(0.005));
quatsOpt = new torch::optim::Adam({quats}, torch::optim::AdamOptions(0.001));
featuresDcOpt = new torch::optim::Adam({featuresDc}, torch::optim::AdamOptions(0.0025));
featuresRestOpt = new torch::optim::Adam({featuresRest}, torch::optim::AdamOptions(0.000125));
opacitiesOpt = new torch::optim::Adam({opacities}, torch::optim::AdamOptions(0.05));

meansOptScheduler = new OptimScheduler(meansOpt, 0.0000016f, maxSteps);
setupOptimizers();
}

~Model(){
delete meansOpt;
delete scalesOpt;
delete quatsOpt;
delete featuresDcOpt;
delete featuresRestOpt;
delete opacitiesOpt;

delete meansOptScheduler;
releaseOptimizers();
}

void setupOptimizers();
void releaseOptimizers();

torch::Tensor forward(Camera& cam, int step);
void optimizersZeroGrad();
void optimizersStep();
void schedulersStep(int step);
int getDownscaleFactor(int step);
void afterTrain(int step);
void save(const std::string &filename);
void savePly(const std::string &filename);
void save(const std::string &filename, int step);
void savePly(const std::string &filename, int step);
void saveSplat(const std::string &filename);
void saveDebugPly(const std::string &filename);
void saveDebugPly(const std::string &filename, int step);
int loadPly(const std::string &filename);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);

void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
@@ -96,14 +86,14 @@ struct Model{
torch::Tensor featuresRest;
torch::Tensor opacities;

torch::optim::Adam *meansOpt;
torch::optim::Adam *scalesOpt;
torch::optim::Adam *quatsOpt;
torch::optim::Adam *featuresDcOpt;
torch::optim::Adam *featuresRestOpt;
torch::optim::Adam *opacitiesOpt;
torch::optim::Adam *meansOpt = nullptr;
torch::optim::Adam *scalesOpt = nullptr;
torch::optim::Adam *quatsOpt = nullptr;
torch::optim::Adam *featuresDcOpt = nullptr;
torch::optim::Adam *featuresRestOpt = nullptr;
torch::optim::Adam *opacitiesOpt = nullptr;

OptimScheduler *meansOptScheduler;
OptimScheduler *meansOptScheduler = nullptr;

torch::Tensor radii; // set in forward()
torch::Tensor xys; // set in forward()
16 changes: 12 additions & 4 deletions opensplat.cpp
@@ -20,6 +20,7 @@ int main(int argc, char *argv[]){
("i,input", "Path to nerfstudio project", cxxopts::value<std::string>())
("o,output", "Path where to save output scene", cxxopts::value<std::string>()->default_value("splat.ply"))
("s,save-every", "Save output scene every these many steps (set to -1 to disable)", cxxopts::value<int>()->default_value("-1"))
("resume", "Resume training from this PLY file", cxxopts::value<std::string>()->default_value(""))
("val", "Withhold a camera shot for validating the scene loss")
("val-image", "Filename of the image to withhold for validating scene loss", cxxopts::value<std::string>()->default_value("random"))
("val-render", "Path of the directory where to render validation images", cxxopts::value<std::string>()->default_value(""))
@@ -69,6 +70,7 @@ int main(int argc, char *argv[]){
const std::string projectRoot = result["input"].as<std::string>();
const std::string outputScene = result["output"].as<std::string>();
const int saveEvery = result["save-every"].as<int>();
const std::string resume = result["resume"].as<std::string>();
const bool validate = result.count("val") > 0 || result.count("val-render") > 0;
const std::string valImage = result["val-image"].as<std::string>();
const std::string valRender = result["val-render"].as<std::string>();
@@ -132,7 +134,13 @@ int main(int argc, char *argv[]){
InfiniteRandomIterator<size_t> camsIter( camIndices );

int imageSize = -1;
for (size_t step = 1; step <= numIters; step++){
size_t step = 1;

if (resume != ""){
step = model.loadPly(resume) + 1;
}

for (; step <= numIters; step++){
Camera& cam = cams[ camsIter.next() ];

model.optimizersZeroGrad();
@@ -152,7 +160,7 @@ int main(int argc, char *argv[]){

if (saveEvery > 0 && step % saveEvery == 0){
fs::path p(outputScene);
model.save((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
model.save(p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string(), step);
}

if (!valRender.empty() && step % 10 == 0){
@@ -173,8 +181,8 @@ int main(int argc, char *argv[]){
}

inputData.saveCameras((fs::path(outputScene).parent_path() / "cameras.json").string(), keepCrs);
model.save(outputScene);
// model.saveDebugPly("debug.ply");
model.save(outputScene, numIters);
// model.saveDebugPly("debug.ply", numIters);

// Validate
if (valCam != nullptr){
3 changes: 3 additions & 0 deletions utils.hpp
@@ -8,6 +8,9 @@
#include <thread>
#include <functional>

// Delete an object owned through a raw pointer and reset the pointer to nullptr (no-op when already null)
#define RELEASE_SAFELY(__POINTER) { if (__POINTER != nullptr) { delete __POINTER; __POINTER = nullptr; } }


template <typename T>
class InfiniteRandomIterator
{
