diff --git a/src/ArduinoIoTCloudTCP.h b/src/ArduinoIoTCloudTCP.h index 29dfc754..f1682164 100644 --- a/src/ArduinoIoTCloudTCP.h +++ b/src/ArduinoIoTCloudTCP.h @@ -96,9 +96,18 @@ class ArduinoIoTCloudTCP: public ArduinoIoTCloudClass _get_ota_confirmation = cb; if(_get_ota_confirmation) { - _ota.setOtaPolicies(OTACloudProcessInterface::ApprovalRequired); + _ota.enableOtaPolicy(OTACloudProcessInterface::ApprovalRequired); } else { - _ota.setOtaPolicies(OTACloudProcessInterface::None); + _ota.disableOtaPolicy(OTACloudProcessInterface::ApprovalRequired); + } + } + + /* Slower but more reliable in some corner cases */ + void setOTAChunkMode(bool enable = true) { + if(enable) { + _ota.enableOtaPolicy(OTACloudProcessInterface::ChunkDownload); + } else { + _ota.disableOtaPolicy(OTACloudProcessInterface::ChunkDownload); } } #endif diff --git a/src/ota/interface/OTAInterface.cpp b/src/ota/interface/OTAInterface.cpp index b659917b..e72d062d 100644 --- a/src/ota/interface/OTAInterface.cpp +++ b/src/ota/interface/OTAInterface.cpp @@ -167,10 +167,10 @@ OTACloudProcessInterface::State OTACloudProcessInterface::idle(Message* msg) { OTACloudProcessInterface::State OTACloudProcessInterface::otaAvailable() { // depending on the policy decided on this device the ota process can start immediately // or wait for confirmation from the user - if((policies & (ApprovalRequired | Approved)) == ApprovalRequired ) { + if(getOtaPolicy(ApprovalRequired) && !getOtaPolicy(Approved)) { return OtaAvailable; } else { - policies &= ~Approved; + disableOtaPolicy(Approved); return StartOTA; } // TODO add an abortOTA command? in this case delete the context } diff --git a/src/ota/interface/OTAInterface.h b/src/ota/interface/OTAInterface.h index a62b7cb2..d9624c3b 100644 --- a/src/ota/interface/OTAInterface.h +++ b/src/ota/interface/OTAInterface.h @@ -80,7 +80,8 @@ class OTACloudProcessInterface: public CloudProcess { enum OtaFlags: uint16_t { None = 0, ApprovalRequired = 1, - Approved = 1<<1 + Approved = 1<<1, + ChunkDownload = 1<<2 }; virtual void handleMessage(Message*); @@ -88,9 +89,13 @@ class OTACloudProcessInterface: public CloudProcess { // virtual void hook(State s, void* action); virtual void update() { handleMessage(nullptr); } - inline void approveOta() { policies |= Approved; } + inline void approveOta() { this->policies |= Approved; } inline void setOtaPolicies(uint16_t policies) { this->policies = policies; } + inline void enableOtaPolicy(OtaFlags policyFlag) { this->policies |= policyFlag; } + inline void disableOtaPolicy(OtaFlags policyFlag) { this->policies &= ~policyFlag; } + inline bool getOtaPolicy(OtaFlags policyFlag) { return (this->policies & policyFlag) != 0;} + inline State getState() { return state; } virtual bool isOtaCapable() = 0; diff --git a/src/ota/interface/OTAInterfaceDefault.cpp b/src/ota/interface/OTAInterfaceDefault.cpp index 82bfd9e8..a641595b 100644 --- a/src/ota/interface/OTAInterfaceDefault.cpp +++ b/src/ota/interface/OTAInterfaceDefault.cpp @@ -41,39 +41,17 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() { } ); - // make the http get request + // check url if(strcmp(context->parsed_url.schema(), "https") == 0) { http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port()); } else { return UrlParseErrorFail; } - http_client->beginRequest(); - auto res = http_client->get(context->parsed_url.path()); - - if(username != nullptr && password != nullptr) { - http_client->sendBasicAuth(username, password); - } - - http_client->endRequest(); - - if(res == HTTP_ERROR_CONNECTION_FAILED) { - DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"", - context->parsed_url.host(), context->parsed_url.port()); - return ServerConnectErrorFail; - } else if(res == HTTP_ERROR_TIMED_OUT) { - DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url); - return OtaHeaderTimeoutFail; - } else if(res != HTTP_SUCCESS) { - DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url); - return OtaDownloadFail; - } - - int statusCode = http_client->responseStatusCode(); - - if(statusCode != 200) { - DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode); - return HttpResponseFail; + // make the http get request + OTACloudProcessInterface::State res = requestOta(); + if(res != Fetch) { + return res; } // The following call is required to save the header value , keep it @@ -82,16 +60,27 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() { return HttpHeaderErrorFail; } + context->contentLength = http_client->contentLength(); context->lastReportTime = millis(); - + DEBUG_VERBOSE("OTA file length: %d", context->contentLength); return Fetch; } OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() { OTACloudProcessInterface::State res = Fetch; - int http_res = 0; - uint32_t start = millis(); + if(getOtaPolicy(ChunkDownload)) { + res = requestOta(ChunkDownload); + } + + context->downloadedChunkSize = 0; + context->downloadedChunkStartTime = millis(); + + if(res != Fetch) { + goto exit; + } + + /* download chunked or timed */ do { if(!http_client->connected()) { res = OtaDownloadFail; @@ -104,7 +93,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() { continue; } - http_res = http_client->read(context->buffer, context->buf_len); + int http_res = http_client->read(context->buffer, context->bufLen); if(http_res < 0) { DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res); @@ -119,8 +108,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() { res = ErrorWriteUpdateFileFail; goto exit; } - } while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) && - millis() - start < downloadTime); + + context->downloadedChunkSize += http_res; + + } while(context->downloadState < OtaDownloadCompleted && fetchMore()); // TODO verify that the information present in the ota header match the info in context if(context->downloadState == OtaDownloadCompleted) { @@ -153,13 +144,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() { return res; } -void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) { +OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OtaFlags mode) { + int http_res = 0; + + /* stop connected client */ + http_client->stop(); + + /* request chunk */ + http_client->beginRequest(); + http_res = http_client->get(context->parsed_url.path()); + + if(username != nullptr && password != nullptr) { + http_client->sendBasicAuth(username, password); + } + + if((mode & ChunkDownload) == ChunkDownload) { + char range[128] = {0}; + size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize; + sprintf(range, "bytes=%" PRIu32 "-%" PRIu32, context->downloadedSize, context->downloadedSize + rangeSize); + DEBUG_VERBOSE("OTA downloading range: %s", range); + http_client->sendHeader("Range", range); + } + + http_client->endRequest(); + + if(http_res == HTTP_ERROR_CONNECTION_FAILED) { + DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"", + context->parsed_url.host(), context->parsed_url.port()); + return ServerConnectErrorFail; + } else if(http_res == HTTP_ERROR_TIMED_OUT) { + DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url); + return OtaHeaderTimeoutFail; + } else if(http_res != HTTP_SUCCESS) { + DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url); + return OtaDownloadFail; + } + + int statusCode = http_client->responseStatusCode(); + + if((((mode & ChunkDownload) == ChunkDownload) && (statusCode != 206)) || + (((mode & ChunkDownload) != ChunkDownload) && (statusCode != 200))) { + DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode); + return HttpResponseFail; + } + + http_client->skipResponseHeaders(); + return Fetch; +} + +bool OTADefaultCloudProcessInterface::fetchMore() { + if (getOtaPolicy(ChunkDownload)) { + return context->downloadedChunkSize < maxChunkSize; + } else { + return (millis() - context->downloadedChunkStartTime) < downloadTime; + } +} + +void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) { assert(context != nullptr); // This should never fail - for(uint8_t* cursor=(uint8_t*)buffer; cursordownloadState) { case OtaDownloadHeader: { - const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes; + const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes; memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft); cursor += headerLeft; context->headerCopiedBytes += headerLeft; @@ -184,8 +231,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) break; } case OtaDownloadFile: { - const uint32_t contentLength = http_client->contentLength(); - const uint32_t dataLeft = buf_len - (cursor-buffer); + const uint32_t dataLeft = bufLen - (cursor-buffer); context->decoder.decompress(cursor, dataLeft); // TODO verify return value context->calculatedCrc32 = crc_update( @@ -198,18 +244,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) context->downloadedSize += dataLeft; if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond - DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength); + DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength); reportStatus(context->downloadedSize); context->lastReportTime = millis(); } // TODO there should be no more bytes available when the download is completed - if(context->downloadedSize == contentLength) { + if(context->downloadedSize == context->contentLength) { context->downloadState = OtaDownloadCompleted; } - if(context->downloadedSize > contentLength) { + if(context->downloadedSize > context->contentLength) { context->downloadState = OtaDownloadError; } // TODO fail if we exceed a timeout? and available is 0 (client is broken) @@ -250,7 +296,9 @@ OTADefaultCloudProcessInterface::Context::Context( , headerCopiedBytes(0) , downloadedSize(0) , lastReportTime(0) + , contentLength(0) , writeError(false) + , downloadedChunkSize(0) , decoder(putc) { } static const uint32_t crc_table[256] = { diff --git a/src/ota/interface/OTAInterfaceDefault.h b/src/ota/interface/OTAInterfaceDefault.h index 95384817..b45c14dd 100644 --- a/src/ota/interface/OTAInterfaceDefault.h +++ b/src/ota/interface/OTAInterfaceDefault.h @@ -42,7 +42,9 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface { virtual int writeFlash(uint8_t* const buffer, size_t len) = 0; private: - void parseOta(uint8_t* buffer, size_t buf_len); + void parseOta(uint8_t* buffer, size_t bufLen); + State requestOta(OtaFlags mode = None); + bool fetchMore(); Client* client; HttpClient* http_client; @@ -53,6 +55,10 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface { // This mitigate the issues arising from tasks run in main loop that are using all the computing time static constexpr uint32_t downloadTime = 2000; + // The amount of data that each iteration of Fetch has to take at least + // This should be enabled setting ChunkDownload OtaFlag to 1 and mitigate some Ota corner cases + static constexpr size_t maxChunkSize = 1024 * 10; + enum OTADownloadState: uint8_t { OtaDownloadHeader, OtaDownloadFile, @@ -74,13 +80,17 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface { uint32_t headerCopiedBytes; uint32_t downloadedSize; uint32_t lastReportTime; + uint32_t contentLength; bool writeError; + uint32_t downloadedChunkStartTime; + uint32_t downloadedChunkSize; + // LZSS decoder LZSSDecoder decoder; - const size_t buf_len = 64; - uint8_t buffer[64]; + static constexpr size_t bufLen = 64; + uint8_t buffer[bufLen]; } *context; };