diff --git a/loader/src/image.hpp b/loader/src/image.hpp index 04b4b64c..35b5353c 100644 --- a/loader/src/image.hpp +++ b/loader/src/image.hpp @@ -268,7 +268,8 @@ friend class Video; createRandomAugParams(decodedDatum.size()); transformDecodedImage(decodedDatum, datumBuf, datumLen); Mat decodedTarget; - decode(encTarget, encTargetLen, &decodedTarget); + // Assume grayscale masks for now. + decodeGrayscale(encTarget, encTargetLen, &decodedTarget); transformDecodedImage(decodedTarget, targetBuf, targetLen); } @@ -367,13 +368,21 @@ friend class Video; } private: + void decodeGrayscale(char* item, int itemSize, Mat* dst) { + Mat image(1, itemSize, CV_8UC1, item); + cv::imdecode(image, CV_LOAD_IMAGE_GRAYSCALE, dst); + } + + void decodeColor(char* item, int itemSize, Mat* dst) { + Mat image(1, itemSize, CV_8UC3, item); + cv::imdecode(image, CV_LOAD_IMAGE_COLOR, dst); + } + void decode(char* item, int itemSize, Mat* dst) { if (_params->_channelCount == 1) { - Mat image(1, itemSize, CV_8UC1, item); - cv::imdecode(image, CV_LOAD_IMAGE_GRAYSCALE, dst); + decodeGrayscale(item, itemSize, dst); } else if (_params->_channelCount == 3) { - Mat image(1, itemSize, CV_8UC3, item); - cv::imdecode(image, CV_LOAD_IMAGE_COLOR, dst); + decodeColor(item, itemSize, dst); } else { stringstream ss; ss << "Unsupported number of channels in image: " << _params->_channelCount;