diff --git a/src/Arduino_ESP32_OTA.cpp b/src/Arduino_ESP32_OTA.cpp index ec1cc1b..76c1d02 100644 --- a/src/Arduino_ESP32_OTA.cpp +++ b/src/Arduino_ESP32_OTA.cpp @@ -164,10 +164,10 @@ int Arduino_ESP32_OTA::startDownload(const char * ota_url) } } -Arduino_ESP32_OTA::OTADownloadState Arduino_ESP32_OTA::progressDownload() +int Arduino_ESP32_OTA::progressDownload() { - int http_res = 0; - Arduino_ESP32_OTA::OTADownloadState res = OtaDownloadHeader; + int http_res = static_cast(Error::None);; + int res = 0; if(_http_client->available() == 0) { goto exit; @@ -177,7 +177,7 @@ Arduino_ESP32_OTA::OTADownloadState Arduino_ESP32_OTA::progressDownload() if(http_res < 0) { DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res); - res = OtaDownloadError; + res = static_cast(Error::OtaDownload); goto exit; } @@ -201,7 +201,7 @@ Arduino_ESP32_OTA::OTADownloadState Arduino_ESP32_OTA::progressDownload() if(_context->header.header.magic_number != _magic) { _context->downloadState = OtaDownloadMagicNumberMismatch; - res = _context->downloadState; + res = static_cast(Error::OtaHeaderMagicNumber); goto exit; } @@ -224,20 +224,21 @@ Arduino_ESP32_OTA::OTADownloadState Arduino_ESP32_OTA::progressDownload() // TODO there should be no more bytes available when the download is completed if(_context->downloadedSize == _http_client->contentLength()) { _context->downloadState = OtaDownloadCompleted; - res = _context->downloadState; + res = 1; } if(_context->downloadedSize > _http_client->contentLength()) { _context->downloadState = OtaDownloadError; - res = _context->downloadState; + res = static_cast(Error::OtaDownload); } // TODO fail if we exceed a timeout? and available is 0 (client is broken) break; case OtaDownloadCompleted: + res = 1; goto exit; default: _context->downloadState = OtaDownloadError; - res = _context->downloadState; + res = static_cast(Error::OtaDownload); goto exit; } } @@ -264,7 +265,16 @@ Arduino_ESP32_OTA::OTADownloadState Arduino_ESP32_OTA::progressDownload() int Arduino_ESP32_OTA::downloadProgress() { - return _context->downloadedSize; + if(_context->error != Error::None) { + return static_cast(_context->error); + } else { + return _context->downloadedSize; + } +} + +size_t Arduino_ESP32_OTA::downloadSize() +{ + return _http_client!=nullptr ? _http_client->contentLength() : 0; } int Arduino_ESP32_OTA::download(const char * ota_url) @@ -275,22 +285,10 @@ int Arduino_ESP32_OTA::download(const char * ota_url) return err; } - OTADownloadState res = OtaDownloadHeader; - - while((res = progressDownload()) == OtaDownloadFile || res == OtaDownloadHeader); + int res = 0; + while((res = progressDownload()) <= 0); - - if(res == OtaDownloadCompleted) { - return _context->writtenBytes; - } else { - switch(res) { - case OtaDownloadMagicNumberMismatch: - return static_cast(Error::OtaHeaderMagicNumber); - case OtaDownloadError: - default: - return static_cast(Error::OtaDownload); - } - } + return res == 1? _context->writtenBytes : res; } void Arduino_ESP32_OTA::clean() @@ -368,6 +366,7 @@ Arduino_ESP32_OTA::Context::Context( , calculatedCrc32(0xFFFFFFFF) , headerCopiedBytes(0) , downloadedSize(0) + , error(Error::None) , decoder(putc) { strcpy(this->url, url); }