diff --git a/.bazelversion b/.bazelversion index 04edabd..f9da12e 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -5.4.1 \ No newline at end of file +6.3.2 \ No newline at end of file diff --git a/.circleci/config.yml b/.circleci/config.yml index ea0a898..ce9a937 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,12 +8,12 @@ version: 2.1 executors: linux_x64_executor: # declares a reusable executor docker: - - image: envoyproxy/envoy-build-ubuntu:81a93046060dbe5620d5b3aa92632090a9ee4da6 + - image: envoyproxy/envoy-build-ubuntu:0ca52447572ee105a4730da5e76fe47c9c5a7c64 resource_class: 2xlarge shell: /bin/bash --login -eo pipefail linux_aarch64_executor: docker: - - image: envoyproxy/envoy-build-ubuntu:81a93046060dbe5620d5b3aa92632090a9ee4da6 + - image: envoyproxy/envoy-build-ubuntu:0ca52447572ee105a4730da5e76fe47c9c5a7c64 resource_class: arm.2xlarge shell: /bin/bash --login -eo pipefail @@ -54,26 +54,24 @@ jobs: IMG=secretflow/kuscia-envoy IMG_LATEST={IMG}:latest IMG_TAG={IMG}:{CIRCLETAG} - + ALIYUN_IMG=secretflow-registry.cn-hangzhou.cr.aliyuncs.com/secretflow/kuscia-envoy ALIYUN_IMG_LATEST={ALIYUN_IMG}:latest ALIYUN_IMG_TAG={ALIYUN_IMG}:{CIRCLETAG} - + #login docker docker login -u ${DOCKER_USERNAME} -p ${DOCKER_DEPLOY_TOKEN} - - docker buildx build -t ${IMG_LATEST} --platform linux/amd64 --build-arg ARCH=amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${IMG_LATEST} --platform linux/arm64 --build-arg ARCH=arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${IMG_TAG} --platform linux/amd64 --build-arg ARCH=amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${IMG_TAG} --platform linux/arm64 --build-arg ARCH=arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - + + docker buildx build -t ${IMG_LATEST} --platform linux/arm64,linux/amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push + docker buildx build -t ${IMG_TAG} --platform linux/arm64,linux/amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push + + # login docker - aliyun docker login -u ${ALIYUN_DOCKER_USERNAME} -p ${ALIYUN_DOCKER_PASSWORD} secretflow-registry.cn-hangzhou.cr.aliyuncs.com - - docker buildx build -t ${ALIYUN_IMG_LATEST} --platform linux/amd64 --build-arg ARCH=amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${ALIYUN_IMG_LATEST} --platform linux/arm64 --build-arg ARCH=arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${ALIYUN_IMG_TAG} --platform linux/amd64 --build-arg ARCH=amd64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push - docker buildx build -t ${ALIYUN_IMG_TAG} --platform linux/arm64 --build-arg ARCH=arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push + docker buildx build -t {ALIYUN_IMG_LATEST} --platform linux/amd64,linux/arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push + docker buildx build -t {ALIYUN_IMG_TAG} --platform linux/amd64,linux/arm64 -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . --push + + # Orchestrate jobs using workflows diff --git a/.gitmodules b/.gitmodules index d27b8a8..57d02e3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "envoy"] path = envoy url = https://github.com/envoyproxy/envoy.git - branch = release/v1.20 + branch = release/v1.29 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c32734..e94b23b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `Fixed` for any bug fixes. `Security` in case of vulnerabilities. +## [v0.5.0.dev240430] - 2024-04-30 +### Added +- [Feature] Support for ARM architecture. +- [Feature] Support for reverse tunneling multiple replicas. + +### Changed +- [Upgrade] Upgraded the dependent Envoy version to 1.29.4. + ## [0.2.0b0] - 2023-7-6 ### Added - Kuscia-envoy init release diff --git a/Makefile b/Makefile index a9ded2e..2c9a0e0 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ SHELL := /bin/bash -BUILD_IMAGE = envoyproxy/envoy-build-ubuntu:81a93046060dbe5620d5b3aa92632090a9ee4da6 +BUILD_IMAGE = envoyproxy/envoy-build-ubuntu:0ca52447572ee105a4730da5e76fe47c9c5a7c64 # Image URL to use all building image targets DATETIME = $(shell date +"%Y%m%d%H%M%S") @@ -14,7 +14,7 @@ UNAME_M_OUTPUT := $(shell uname -m) # To configure the ARCH variable to either arm64 or amd64 or UNAME_M_OUTPUT ARCH := $(if $(filter aarch64 arm64,$(UNAME_M_OUTPUT)),arm64,$(if $(filter amd64 x86_64,$(UNAME_M_OUTPUT)),amd64,$(UNAME_M_OUTPUT))) -CONTAINER_NAME ?= "build-envoy" +CONTAINER_NAME ?= "build-envoy-$(shell echo ${USER})" COMPILE_MODE ?=opt TARGET ?= "//:envoy" BUILD_OPTS ?="--strip=always" @@ -28,7 +28,7 @@ define start_docker git submodule update --init;\ fi; if [[ ! -n $$(docker ps -q -f "name=^$(CONTAINER_NAME)$$") ]]; then\ - docker run -itd --rm -v $(shell pwd):/home/admin/dev -v $(shell pwd)/cache:/root/.cache/bazel -w /home/admin/dev --name $(CONTAINER_NAME) \ + docker run -itd --rm -v $(shell pwd)/cache:/root/.cache/bazel -v $(shell pwd):/home/admin/dev -w /home/admin/dev --name $(CONTAINER_NAME) \ -e GOPROXY='https://goproxy.cn,direct' --cap-add=NET_ADMIN $(BUILD_IMAGE);\ docker exec -it $(CONTAINER_NAME) /bin/bash -c 'git config --global --add safe.directory /home/admin/dev';\ fi; @@ -72,7 +72,6 @@ clean: $(call stop_docker) rm -rf output - .PHONY: image image: build-envoy - docker build -t ${IMG} --build-arg ARCH=${ARCH} -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . + docker build -t ${IMG} -f ./build_image/dockerfile/kuscia-envoy-anolis.Dockerfile . diff --git a/WORKSPACE b/WORKSPACE index 72a6f05..ed99707 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -17,9 +17,10 @@ load("@envoy//bazel:repositories.bzl", "envoy_dependencies") envoy_dependencies() +#https://github.com/envoyproxy/envoy/issues/28670 load("@envoy//bazel:repositories_extra.bzl", "envoy_dependencies_extra") -envoy_dependencies_extra() +envoy_dependencies_extra(ignore_root_user_error = True) load("@envoy//bazel:python_dependencies.bzl", "envoy_python_dependencies") diff --git a/build_image/dockerfile/kuscia-envoy-anolis.Dockerfile b/build_image/dockerfile/kuscia-envoy-anolis.Dockerfile index 6c2af53..2e303de 100644 --- a/build_image/dockerfile/kuscia-envoy-anolis.Dockerfile +++ b/build_image/dockerfile/kuscia-envoy-anolis.Dockerfile @@ -1,12 +1,12 @@ FROM openanolis/anolisos:8.8 -ARG ARCH +ARG TARGETPLATFORM ENV TZ=Asia/Shanghai ARG ROOT_DIR="/home/kuscia" -COPY ./output/linux/$ARCH $ROOT_DIR/ +COPY ./output/$TARGETPLATFORM $ROOT_DIR/ WORKDIR ${ROOT_DIR} diff --git a/envoy b/envoy index bae2e9d..8eef22b 160000 --- a/envoy +++ b/envoy @@ -1 +1 @@ -Subproject commit bae2e9d642a6a8ae6c5d3810f77f3e888f0d97da +Subproject commit 8eef22b927682e9ff6f59cf9f26e440b41219fe6 diff --git a/kuscia/api/filters/http/kuscia_crypt/v3/BUILD b/kuscia/api/filters/http/kuscia_crypt/v3/BUILD index c496ff4..b514f18 100755 --- a/kuscia/api/filters/http/kuscia_crypt/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_crypt/v3/BUILD @@ -6,6 +6,6 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", ], ) diff --git a/kuscia/api/filters/http/kuscia_gress/v3/BUILD b/kuscia/api/filters/http/kuscia_gress/v3/BUILD index 95ec17e..d8773a7 100755 --- a/kuscia/api/filters/http/kuscia_gress/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_gress/v3/BUILD @@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", - "@envoy_api//envoy/type/matcher/v3:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", + "@envoy_api//envoy/type/matcher/v3:pkg", ], ) diff --git a/kuscia/api/filters/http/kuscia_header_decorator/v3/BUILD b/kuscia/api/filters/http/kuscia_header_decorator/v3/BUILD index c496ff4..b514f18 100644 --- a/kuscia/api/filters/http/kuscia_header_decorator/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_header_decorator/v3/BUILD @@ -6,6 +6,6 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", ], ) diff --git a/kuscia/api/filters/http/kuscia_poller/v3/BUILD b/kuscia/api/filters/http/kuscia_poller/v3/BUILD index c496ff4..b514f18 100755 --- a/kuscia/api/filters/http/kuscia_poller/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_poller/v3/BUILD @@ -6,6 +6,6 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", ], ) diff --git a/kuscia/api/filters/http/kuscia_receiver/v3/BUILD b/kuscia/api/filters/http/kuscia_receiver/v3/BUILD index c496ff4..b514f18 100644 --- a/kuscia/api/filters/http/kuscia_receiver/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_receiver/v3/BUILD @@ -6,6 +6,6 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", ], ) diff --git a/kuscia/api/filters/http/kuscia_token_auth/v3/BUILD b/kuscia/api/filters/http/kuscia_token_auth/v3/BUILD index c496ff4..b514f18 100755 --- a/kuscia/api/filters/http/kuscia_token_auth/v3/BUILD +++ b/kuscia/api/filters/http/kuscia_token_auth/v3/BUILD @@ -6,6 +6,6 @@ licenses(["notice"]) # Apache 2 api_proto_package( deps = [ - "@com_github_cncf_udpa//udpa/annotations:pkg", + "@com_github_cncf_xds//udpa/annotations:pkg", ], ) diff --git a/kuscia/source/filters/http/kuscia_common/coder.cc b/kuscia/source/filters/http/kuscia_common/coder.cc index 02cf8aa..f14621e 100644 --- a/kuscia/source/filters/http/kuscia_common/coder.cc +++ b/kuscia/source/filters/http/kuscia_common/coder.cc @@ -19,20 +19,20 @@ namespace Extensions { namespace HttpFilters { namespace KusciaCommon { -DecodeStatus KusciaCommon::Decoder::decode(Envoy::Buffer::Instance& data, google::protobuf::Message &message) -{ - DecodeStatus status = frameReader_.read(data); - if (status != DecodeStatus::Ok) { - return status; - } +DecodeStatus KusciaCommon::Decoder::decode(Envoy::Buffer::Instance& data, + google::protobuf::Message& message) { + DecodeStatus status = frameReader_.read(data); + if (status != DecodeStatus::Ok) { + return status; + } - auto data_frame = frameReader_.getDataFrame(); + auto data_frame = frameReader_.getDataFrame(); - if (!message.ParseFromArray(data_frame.data(), data_frame.size())) { - return DecodeStatus::ErrorInvalidData; - } + if (!message.ParseFromArray(data_frame.data(), data_frame.size())) { + return DecodeStatus::ErrorInvalidData; + } - return DecodeStatus::Ok; + return DecodeStatus::Ok; } } // namespace KusciaCommon diff --git a/kuscia/source/filters/http/kuscia_common/coder.h b/kuscia/source/filters/http/kuscia_common/coder.h index 983380b..c332aca 100644 --- a/kuscia/source/filters/http/kuscia_common/coder.h +++ b/kuscia/source/filters/http/kuscia_common/coder.h @@ -24,10 +24,10 @@ namespace KusciaCommon { class Decoder { public: - DecodeStatus decode(Envoy::Buffer::Instance& data, google::protobuf::Message& message); + DecodeStatus decode(Envoy::Buffer::Instance& data, google::protobuf::Message& message); private: - LengthDelimitedFrameReader frameReader_; + LengthDelimitedFrameReader frameReader_; }; } // namespace KusciaCommon diff --git a/kuscia/source/filters/http/kuscia_common/framer.cc b/kuscia/source/filters/http/kuscia_common/framer.cc index f79d7ba..03844b7 100644 --- a/kuscia/source/filters/http/kuscia_common/framer.cc +++ b/kuscia/source/filters/http/kuscia_common/framer.cc @@ -12,88 +12,86 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include #include +#include -#include "kuscia/source/filters/http/kuscia_common/framer.h" #include "framer.h" +#include "kuscia/source/filters/http/kuscia_common/framer.h" namespace Envoy { namespace Extensions { namespace HttpFilters { namespace KusciaCommon { - static const std::map decodeStatusMessageMap = { {DecodeStatus::NeedMoreData, "need more data"}, {DecodeStatus::ErrorObjectTooLarge, "object too large"}, - {DecodeStatus::ErrorInvalidData, "invalid data"} -}; + {DecodeStatus::ErrorInvalidData, "invalid data"}}; absl::string_view decodeStatusString(DecodeStatus status) { - auto it = decodeStatusMessageMap.find(status); - if (it != decodeStatusMessageMap.end()) { - return it->second; - } else { - return ""; - } + auto it = decodeStatusMessageMap.find(status); + if (it != decodeStatusMessageMap.end()) { + return it->second; + } else { + return ""; + } } -DecodeStatus LengthDelimitedFrameReader::read(Buffer::Instance& input) -{ - if (remaining_ == 0) { - uint32_t frameLength; - if (!readByLen(input, sizeof(frameLength), len_frame_)) { - return DecodeStatus::NeedMoreData; - } - - std::memcpy(&frameLength, len_frame_.data(), sizeof(frameLength)); - remaining_ = ntohl(frameLength); +DecodeStatus LengthDelimitedFrameReader::read(Buffer::Instance& input) { + if (remaining_ == 0) { + uint32_t frameLength; + if (!readByLen(input, sizeof(frameLength), len_frame_)) { + return DecodeStatus::NeedMoreData; + } - if (remaining_ > maxBytes_) { - return DecodeStatus::ErrorObjectTooLarge; - } + std::memcpy(&frameLength, len_frame_.data(), sizeof(frameLength)); + remaining_ = ntohl(frameLength); - len_frame_.resize(0); - data_frame_.resize(0); + if (remaining_ > maxBytes_) { + return DecodeStatus::ErrorObjectTooLarge; } - if (!readByLen(input, remaining_, data_frame_)) { - return DecodeStatus::NeedMoreData; - } + len_frame_.resize(0); + data_frame_.resize(0); + } + + if (!readByLen(input, remaining_, data_frame_)) { + return DecodeStatus::NeedMoreData; + } - remaining_ = 0; - return DecodeStatus::Ok; + remaining_ = 0; + return DecodeStatus::Ok; } -bool LengthDelimitedFrameReader::readByLen(Buffer::Instance& input, size_t len, std::vector& frame) -{ - size_t frame_size = frame.size(); - size_t input_len = input.length(); +bool LengthDelimitedFrameReader::readByLen(Buffer::Instance& input, size_t len, + std::vector& frame) { + size_t frame_size = frame.size(); + size_t input_len = input.length(); - if (frame_size + input_len < len) { - ENVOY_LOG(info, "Need more input data, frame size: {} + input-len: {} < {}", frame_size, input_len, len); + if (frame_size + input_len < len) { + ENVOY_LOG(info, "Need more input data, frame size: {} + input-len: {} < {}", frame_size, + input_len, len); - frame.resize(frame_size + input_len); - input.copyOut(0, input_len, frame.data() + frame_size); - input.drain(input_len); - return false; // need more input data - } + frame.resize(frame_size + input_len); + input.copyOut(0, input_len, frame.data() + frame_size); + input.drain(input_len); + return false; // need more input data + } - frame.resize(len); - input.copyOut(0, len - frame_size, frame.data() + frame_size); - input.drain(len - frame_size); + frame.resize(len); + input.copyOut(0, len - frame_size, frame.data() + frame_size); + input.drain(len - frame_size); - return true; + return true; } -void KusciaCommon::LengthDelimitedFrameWriter::write(const char data[], uint32_t size, Buffer::OwnedImpl &output) -{ - uint32_t net_size = htonl(size); - output.add(&net_size, sizeof(net_size)); - output.add(data, size); +void KusciaCommon::LengthDelimitedFrameWriter::write(const char data[], uint32_t size, + Buffer::OwnedImpl& output) { + uint32_t net_size = htonl(size); + output.add(&net_size, sizeof(net_size)); + output.add(data, size); } } // namespace KusciaCommon diff --git a/kuscia/source/filters/http/kuscia_common/framer.h b/kuscia/source/filters/http/kuscia_common/framer.h index c0df2ad..6f61ecf 100644 --- a/kuscia/source/filters/http/kuscia_common/framer.h +++ b/kuscia/source/filters/http/kuscia_common/framer.h @@ -14,9 +14,9 @@ #pragma once -#include #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" +#include namespace Envoy { namespace Extensions { @@ -24,34 +24,34 @@ namespace HttpFilters { namespace KusciaCommon { enum class DecodeStatus { - Ok, - NeedMoreData, - ErrorObjectTooLarge, - ErrorInvalidData, + Ok, + NeedMoreData, + ErrorObjectTooLarge, + ErrorInvalidData, }; -extern absl::string_view decodeStatusString(DecodeStatus status); +extern absl::string_view decodeStatusString(DecodeStatus status); class LengthDelimitedFrameReader : public Logger::Loggable { public: - LengthDelimitedFrameReader() : remaining_(0), maxBytes_(16 * 1024 * 1024) {} + LengthDelimitedFrameReader() : remaining_(0), maxBytes_(16 * 1024 * 1024) {} - DecodeStatus read(Buffer::Instance& input); + DecodeStatus read(Buffer::Instance& input); - const std::vector& getDataFrame() const { return data_frame_; }; + const std::vector& getDataFrame() const { return data_frame_; }; private: - bool readByLen(Buffer::Instance& input, size_t len, std::vector& frame); + bool readByLen(Buffer::Instance& input, size_t len, std::vector& frame); - size_t remaining_; - const size_t maxBytes_; - std::vector len_frame_; - std::vector data_frame_; + size_t remaining_; + const size_t maxBytes_; + std::vector len_frame_; + std::vector data_frame_; }; class LengthDelimitedFrameWriter { public: - void write(const char data[], uint32_t size, Buffer::OwnedImpl& output); + void write(const char data[], uint32_t size, Buffer::OwnedImpl& output); }; } // namespace KusciaCommon diff --git a/kuscia/source/filters/http/kuscia_common/kuscia_header.cc b/kuscia/source/filters/http/kuscia_common/kuscia_header.cc index c22b790..0cff8de 100644 --- a/kuscia/source/filters/http/kuscia_common/kuscia_header.cc +++ b/kuscia/source/filters/http/kuscia_common/kuscia_header.cc @@ -22,22 +22,40 @@ namespace KusciaCommon { constexpr absl::string_view InterConnProtocolBFIA{"bfia"}; constexpr absl::string_view InterConnProtocolKuscia{"kuscia"}; -absl::optional -KusciaHeader::getSource(const Http::RequestHeaderMap &headers) { - auto kusciaSource = headers.getByKey(HeaderKeyKusciaSource); - if (kusciaSource) { - return kusciaSource; +absl::optional KusciaHeader::getSource(const Http::RequestHeaderMap& headers) { + absl::string_view kusciaSource; + auto source = headers.get(HeaderKeyKusciaSource); + if (!source.empty()) { + return source[0]->value().getStringView(); } // BFIA protocol - auto protocol = headers.getByKey(KusciaCommon::HeaderKeyInterConnProtocol); - if (protocol && protocol.value() == InterConnProtocolBFIA) { - auto ptpSource = headers.getByKey(HeaderKeyBFIAPTPSource); - return ptpSource ? ptpSource - : headers.getByKey(HeaderKeyBFIAScheduleSource); + auto protocol = headers.get(KusciaCommon::HeaderKeyInterConnProtocol); + if (!protocol.empty() && std::string(protocol[0]->value().getStringView()) == InterConnProtocolBFIA) { + auto ptpSource = headers.get(HeaderKeyBFIAPTPSource); + if (!ptpSource.empty()) { + return ptpSource[0]->value().getStringView(); + } + + auto scheduleSource = headers.get(HeaderKeyBFIAScheduleSource); + if (!scheduleSource.empty()) { + return scheduleSource[0]->value().getStringView(); + } } return kusciaSource; } +void adjustContentLength(Http::RequestOrResponseHeaderMap& headers, int64_t delta_length) { + auto length_header = headers.getContentLengthValue(); + if (!length_header.empty()) { + int64_t old_length; + if (absl::SimpleAtoi(length_header, &old_length)) { + if (old_length > 0 && old_length + delta_length >= 0) { + headers.setContentLength(old_length + delta_length); + } + } + } +} + } // namespace KusciaCommon } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_common/kuscia_header.h b/kuscia/source/filters/http/kuscia_common/kuscia_header.h index 8aa959f..01b006f 100755 --- a/kuscia/source/filters/http/kuscia_common/kuscia_header.h +++ b/kuscia/source/filters/http/kuscia_common/kuscia_header.h @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include "envoy/http/header_map.h" +#include #include "re2/re2.h" @@ -34,20 +34,19 @@ const Http::LowerCaseString HeaderKeyKusciaToken("Kuscia-Token"); const Http::LowerCaseString HeaderKeyKusciaHost("Kuscia-Host"); const Http::LowerCaseString HeaderKeyOriginSource("Kuscia-Origin-Source"); - const Http::LowerCaseString HeaderKeyErrorMessage("Kuscia-Error-Message"); -const Http::LowerCaseString HeaderKeyFmtError("Kuscia-Error-Formatted"); const Http::LowerCaseString HeaderKeyErrorMessageInternal("Kuscia-Error-Message-Internal"); const Http::LowerCaseString HeaderKeyRecordBody("Kuscia-Record-Body"); const Http::LowerCaseString HeaderKeyEncryptVersion("Kuscia-Encrypt-Version"); const Http::LowerCaseString HeaderKeyEncryptIv("Kuscia-Encrypt-Iv"); -const Http::LowerCaseString HeaderKeyForwardRequestId("Kuscia-Foward-Request-Id"); +const Http::LowerCaseString HeaderTransitFlag("Kuscia-Transit-Flag"); +const Http::LowerCaseString HeaderTransitHash("Kuscia-Transit-Hash"); class KusciaHeader { - public: - static absl::optional getSource(const Http::RequestHeaderMap& headers); +public: + static absl::optional getSource(const Http::RequestHeaderMap& headers); }; // receiver.${peer}.svc/poll?timeout=xxx&service=xxx @@ -66,6 +65,8 @@ const std::string GatewayUnregisterPath("/svc/unregister"); const std::string InternalClusterHost("127.0.0.1:80"); +void adjustContentLength(Http::RequestOrResponseHeaderMap& headers, int64_t delta_length); + } // namespace KusciaCommon } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_crypt/config.cc b/kuscia/source/filters/http/kuscia_crypt/config.cc index 05702e1..9ed4a80 100755 --- a/kuscia/source/filters/http/kuscia_crypt/config.cc +++ b/kuscia/source/filters/http/kuscia_crypt/config.cc @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_crypt/config.h" #include "envoy/registry/registry.h" @@ -26,16 +25,14 @@ namespace KusciaCrypt { Http::FilterFactoryCb CryptConfigFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::http::kuscia_crypt::v3::Crypt& proto_config, - const std::string&, - Server::Configuration::FactoryContext&) { - CryptConfigSharedPtr config = std::make_shared(proto_config); - return [config](Http::FilterChainFactoryCallbacks & callbacks) -> void { - callbacks.addStreamFilter(std::make_shared(config)); - }; + const std::string&, Server::Configuration::FactoryContext&) { + CryptConfigSharedPtr config = std::make_shared(proto_config); + return [config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamFilter(std::make_shared(config)); + }; } -REGISTER_FACTORY(CryptConfigFactory, - Server::Configuration::NamedHttpFilterConfigFactory); +REGISTER_FACTORY(CryptConfigFactory, Server::Configuration::NamedHttpFilterConfigFactory); } // namespace KusciaCrypt } // namespace HttpFilters diff --git a/kuscia/source/filters/http/kuscia_crypt/config.h b/kuscia/source/filters/http/kuscia_crypt/config.h index 33b956a..e4e2c82 100755 --- a/kuscia/source/filters/http/kuscia_crypt/config.h +++ b/kuscia/source/filters/http/kuscia_crypt/config.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -27,15 +26,14 @@ namespace Extensions { namespace HttpFilters { namespace KusciaCrypt { -class CryptConfigFactory : public Extensions::HttpFilters::Common::FactoryBase < - envoy::extensions::filters::http::kuscia_crypt::v3::Crypt > { - public: - CryptConfigFactory() : FactoryBase("envoy.filters.http.kuscia_crypt") {} +class CryptConfigFactory : public Extensions::HttpFilters::Common::FactoryBase< + envoy::extensions::filters::http::kuscia_crypt::v3::Crypt> { +public: + CryptConfigFactory() : FactoryBase("envoy.filters.http.kuscia_crypt") {} - Http::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_crypt::v3::Crypt&, - const std::string&, - Server::Configuration::FactoryContext&) override; + Http::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::http::kuscia_crypt::v3::Crypt&, const std::string&, + Server::Configuration::FactoryContext&) override; }; } // namespace KusciaCrypt diff --git a/kuscia/source/filters/http/kuscia_crypt/crypt_filter.cc b/kuscia/source/filters/http/kuscia_crypt/crypt_filter.cc index 2f9f1a6..801ecfb 100755 --- a/kuscia/source/filters/http/kuscia_crypt/crypt_filter.cc +++ b/kuscia/source/filters/http/kuscia_crypt/crypt_filter.cc @@ -1,22 +1,21 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_crypt/crypt_filter.h" - -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "kuscia/source/filters/http/kuscia_common/common.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include namespace Envoy { namespace Extensions { @@ -24,185 +23,215 @@ namespace HttpFilters { namespace KusciaCrypt { static std::string getNamespaceFromHost(absl::string_view host) { - std::vector fields = absl::StrSplit(host, "."); - for (std::size_t i = 0; i < fields.size(); i++) { - if (fields[i] == "svc" && i > 0) { - return std::string(fields[i - 1]); - } + std::vector fields = absl::StrSplit(host, "."); + for (std::size_t i = 0; i < fields.size(); i++) { + if (fields[i] == "svc" && i > 0) { + return std::string(fields[i - 1]); } - return ""; + } + return ""; } Http::FilterHeadersStatus CryptFilter::decodeHeaders(Http::RequestHeaderMap& headers, bool end_stream) { - request_id_ = std::string(headers.getRequestIdValue()); + request_id_ = std::string(headers.getRequestIdValue()); - // create encpyter for internal to external request - if (config_->forwardEncryption()) { - createForwardCrypter(headers, end_stream); - } + // create encpyter for internal to external request + if (config_->forwardEncryption()) { + createForwardCrypter(headers, end_stream); + } - // create decrpter for external to internal request - if (config_->reverseEncryption()) { - createReverseCrypter(headers, end_stream); - } + // create decrpter for external to internal request + if (config_->reverseEncryption()) { + createReverseCrypter(headers, end_stream); + } - return Http::FilterHeadersStatus::Continue; + return Http::FilterHeadersStatus::Continue; } Http::FilterDataStatus CryptFilter::decodeData(Buffer::Instance& data, bool end_stream) { - // decrypt external to internal request body - if (reverse_crypter_) { - reverse_crypter_->decrypt(data, end_stream, left_data_); - ENVOY_LOG(debug, "decrypt request of {}, decrypted length: {}, remain length{}.", - request_id_, data.length(), left_data_.length()); + // decrypt external to internal request body + if (reverse_crypter_) { + if (!reverse_crypter_->decrypt(data, end_stream, left_data_)) { + sendCryptErrorReply(); + return Http::FilterDataStatus::StopIterationNoBuffer; } - - // encrpt request body from internal to external - if (forward_crypter_) { - forward_crypter_->encrypt(data, end_stream, left_data_); - ENVOY_LOG(debug, "encrypt request of {}, encrypted length: {}, remain length{}.", - request_id_, data.length(), left_data_.length()); - + ENVOY_LOG(debug, "decrypt request of {}, decrypted length: {}, remain length {}.", request_id_, + data.length(), left_data_.length()); + } + + // encrpt request body from internal to external + if (forward_crypter_) { + if (!forward_crypter_->encrypt(data, end_stream, left_data_)) { + sendCryptErrorReply(); + return Http::FilterDataStatus::StopIterationNoBuffer; } - return Http::FilterDataStatus::Continue; + ENVOY_LOG(debug, "encrypt request of {}, encrypted length: {}, remain length {}.", request_id_, + data.length(), left_data_.length()); + } + return Http::FilterDataStatus::Continue; } Http::FilterHeadersStatus CryptFilter::encodeHeaders(Http::ResponseHeaderMap& headers, bool end_stream) { - // resp_decrypter use same key with req_encrypter - if (config_->forwardEncryption()) { - checkRespDecrypt(headers, end_stream); - } + // resp_decrypter use same key with req_encrypter + if (config_->forwardEncryption()) { + checkRespDecrypt(headers, end_stream); + } - if (config_->reverseEncryption()) { - checkRespEncrypt(headers, end_stream); - } + if (config_->reverseEncryption()) { + checkRespEncrypt(headers, end_stream); + } - return Http::FilterHeadersStatus::Continue; + return Http::FilterHeadersStatus::Continue; } Http::FilterDataStatus CryptFilter::encodeData(Buffer::Instance& data, bool end_stream) { - // decrypt response from external to internal - if (forward_crypter_ && enable_resp_decrypt_) { - forward_crypter_->decrypt(data, end_stream, left_data_); - ENVOY_LOG(debug, "decrypt response of {}, decrypted length: {}, remain length{}.", - request_id_, data.length(), left_data_.length()); + // decrypt response from external to internal + if (forward_crypter_ && enable_resp_decrypt_) { + if (!forward_crypter_->decrypt(data, end_stream, left_data_)) { + sendCryptErrorReply(); + return Http::FilterDataStatus::StopIterationNoBuffer; } - - // encrypt response from internal to external - if (reverse_crypter_ && enable_resp_encrypt_) { - reverse_crypter_->encrypt(data, end_stream, left_data_); - ENVOY_LOG(debug, "encrypt response of {}, encrypted length: {}, remain length{}.", - request_id_, data.length(), left_data_.length()); + ENVOY_LOG(debug, "decrypt response of {}, decrypted length: {}, remain length {}.", + request_id_, data.length(), left_data_.length()); + } + + // encrypt response from internal to external + if (reverse_crypter_ && enable_resp_encrypt_) { + if (!reverse_crypter_->encrypt(data, end_stream, left_data_)) { + sendCryptErrorReply(); + return Http::FilterDataStatus::StopIterationNoBuffer; } - return Http::FilterDataStatus::Continue; + ENVOY_LOG(debug, "encrypt response of {}, encrypted length: {}, remain length {}.", + request_id_, data.length(), left_data_.length()); + } + return Http::FilterDataStatus::Continue; } -void CryptFilter::createForwardCrypter(Http::RequestHeaderMap& headers, - bool) { - std::string source = - std::string(headers.getByKey(KusciaCommon::HeaderKeyOriginSource).value_or(config_->selfNamespace())); - auto host = headers.getHostValue(); - std::string dest = getNamespaceFromHost(host); - const auto* rule = config_->getEncryptRule(source, dest); - KUSCIA_RETURN_IF(rule == nullptr); - forward_crypter_ = KusciaCrypter::createForwardCrypter(*rule, headers); - if (forward_crypter_ == nullptr) { - ENVOY_LOG(warn, "create forward encrypter failed. source{}, Dest{}", source, dest); - return; - } +void CryptFilter::createForwardCrypter(Http::RequestHeaderMap& headers, bool) { + std::string source; + auto sourceValue = headers.get(KusciaCommon::HeaderKeyOriginSource); + if (!sourceValue.empty()) { + source = std::string(sourceValue[0]->value().getStringView()); + }else { + source = std::string(config_->selfNamespace()); + } + + auto host = headers.getHostValue(); + std::string dest = getNamespaceFromHost(host); + const auto* rule = config_->getEncryptRule(source, dest); + KUSCIA_RETURN_IF(rule == nullptr); + forward_crypter_ = KusciaCrypter::createForwardCrypter(*rule, headers); + if (forward_crypter_ == nullptr) { + ENVOY_LOG(warn, "create forward encrypter failed. source {}, dest {}", source, dest); + return; + } + KusciaCommon::adjustContentLength(headers, +AES_GCM_TAG_LENGTH); } -void CryptFilter::createReverseCrypter(Http::RequestHeaderMap& headers, - bool) { - // get Rule - auto peer_host = headers.getByKey(KusciaCommon::HeaderKeyKusciaHost).value_or(std::string()); - std::string dest = getNamespaceFromHost(peer_host); - std::string source = - std::string(headers.getByKey(KusciaCommon::HeaderKeyOriginSource).value_or(std::string())); - KUSCIA_RETURN_IF(source.empty() || dest.empty()); - const auto* rule = config_->getDecryptRule(source, dest); - KUSCIA_RETURN_IF(rule == nullptr); - - // reverse_crypter_ need to be create even if end_stream == true, cause that req_decrypter is - // also resp_encrypter - reverse_crypter_ = KusciaCrypter::createReverseCrypter(*rule, headers); - if (reverse_crypter_ == nullptr) { - ENVOY_LOG(warn, "create reverse crypter failed. Source{}, Dest{}.", source, dest); - return; - } - - peer_origin_source_ = std::move(source); +void CryptFilter::createReverseCrypter(Http::RequestHeaderMap& headers, bool) { + // get Rule + absl::string_view peer_host; + auto host = headers.get(KusciaCommon::HeaderKeyKusciaHost); + if (!host.empty()){ + peer_host = host[0]->value().getStringView(); + } + std::string dest = getNamespaceFromHost(peer_host); + + std::string source; + auto value = headers.get(KusciaCommon::HeaderKeyOriginSource); + if (!value.empty()) { + source = std::string(value[0]->value().getStringView()); + } + + KUSCIA_RETURN_IF(source.empty() || dest.empty()); + const auto* rule = config_->getDecryptRule(source, dest); + KUSCIA_RETURN_IF(rule == nullptr); + + // reverse_crypter_ need to be create even if end_stream == true, cause that + // req_decrypter is also resp_encrypter + reverse_crypter_ = KusciaCrypter::createReverseCrypter(*rule, headers); + if (reverse_crypter_ == nullptr) { + ENVOY_LOG(warn, "create reverse crypter failed. source {}, dest {}.", source, dest); + return; + } + KusciaCommon::adjustContentLength(headers, -AES_GCM_TAG_LENGTH); + peer_origin_source_ = std::move(source); } void CryptFilter::checkRespEncrypt(Http::ResponseHeaderMap& headers, bool end_stream) { - enable_resp_encrypt_ = false; + enable_resp_encrypt_ = false; - KUSCIA_RETURN_IF(end_stream); - KUSCIA_RETURN_IF(!reverse_crypter_ || !reverse_crypter_->checkRespEncrypt(headers)); + KUSCIA_RETURN_IF(end_stream); + KUSCIA_RETURN_IF(!reverse_crypter_ || !reverse_crypter_->checkRespEncrypt(headers)); - headers.addCopy(KusciaCommon::HeaderKeyOriginSource, peer_origin_source_); + headers.addCopy(KusciaCommon::HeaderKeyOriginSource, peer_origin_source_); - enable_resp_encrypt_ = true; + enable_resp_encrypt_ = true; + KusciaCommon::adjustContentLength(headers, +AES_GCM_TAG_LENGTH); } void CryptFilter::checkRespDecrypt(Http::ResponseHeaderMap& headers, bool end_stream) { - enable_resp_decrypt_ = false; + enable_resp_decrypt_ = false; - // check origin namespace - auto origin_source = headers.get(KusciaCommon::HeaderKeyOriginSource); - KUSCIA_RETURN_IF(origin_source.size() != 1 || origin_source[0] == nullptr || - origin_source[0]->value() != config_->selfNamespace()); + // check origin namespace + auto origin_source = headers.get(KusciaCommon::HeaderKeyOriginSource); + KUSCIA_RETURN_IF(origin_source.size() != 1 || origin_source[0] == nullptr || + origin_source[0]->value() != config_->selfNamespace()); - KUSCIA_RETURN_IF(end_stream); - KUSCIA_RETURN_IF(!forward_crypter_ || !forward_crypter_->checkRespDecrypt(headers)); + KUSCIA_RETURN_IF(end_stream); + KUSCIA_RETURN_IF(!forward_crypter_ || !forward_crypter_->checkRespDecrypt(headers)); - enable_resp_decrypt_ = true; + enable_resp_decrypt_ = true; + KusciaCommon::adjustContentLength(headers, -AES_GCM_TAG_LENGTH); } -KusciaCryptConfig::KusciaCryptConfig(const CryptPbConfig& config) { - namespace_ = config.self_namespace(); - for (const auto& rule : config.encrypt_rules()) { - encrypt_rules_.emplace(std::make_pair(rule.source(), rule.destination()), rule); - } - for (const auto& rule : config.decrypt_rules()) { - decrypt_rules_.emplace(std::make_pair(rule.source(), rule.destination()), rule); - } +void CryptFilter::sendCryptErrorReply() { + decoder_callbacks_->sendLocalReply( + Http::Code::InternalServerError, "Internal Server Error(Crypt Failed)", + [](Http::ResponseHeaderMap& response_headers) { + response_headers.setReferenceKey(Envoy::Http::LowerCaseString("content-type"), + "text/plain"); + }, + absl::nullopt, ""); } -const std::string& KusciaCryptConfig::selfNamespace() const { - return namespace_; +KusciaCryptConfig::KusciaCryptConfig(const CryptPbConfig& config) { + namespace_ = config.self_namespace(); + for (const auto& rule : config.encrypt_rules()) { + encrypt_rules_.emplace(std::make_pair(rule.source(), rule.destination()), rule); + } + for (const auto& rule : config.decrypt_rules()) { + decrypt_rules_.emplace(std::make_pair(rule.source(), rule.destination()), rule); + } } -bool KusciaCryptConfig::forwardEncryption() const { - return !encrypt_rules_.empty(); -} +const std::string& KusciaCryptConfig::selfNamespace() const { return namespace_; } -bool KusciaCryptConfig::reverseEncryption() const { - return !decrypt_rules_.empty(); -} +bool KusciaCryptConfig::forwardEncryption() const { return !encrypt_rules_.empty(); } + +bool KusciaCryptConfig::reverseEncryption() const { return !decrypt_rules_.empty(); } const CryptRule* KusciaCryptConfig::getEncryptRule(const std::string& source, const std::string& dest) const { - auto iter = encrypt_rules_.find(std::make_pair(source, dest)); - if (iter == encrypt_rules_.end()) { - return nullptr; - } - return &(iter->second); + auto iter = encrypt_rules_.find(std::make_pair(source, dest)); + if (iter == encrypt_rules_.end()) { + return nullptr; + } + return &(iter->second); } const CryptRule* KusciaCryptConfig::getDecryptRule(const std::string& source, const std::string& dest) const { - auto iter = decrypt_rules_.find(std::make_pair(source, dest)); - if (iter == decrypt_rules_.end()) { - return nullptr; - } - return &(iter->second); + auto iter = decrypt_rules_.find(std::make_pair(source, dest)); + if (iter == decrypt_rules_.end()) { + return nullptr; + } + return &(iter->second); } } // namespace KusciaCrypt } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/source/filters/http/kuscia_crypt/crypt_filter.h b/kuscia/source/filters/http/kuscia_crypt/crypt_filter.h index e1119a6..af8382b 100755 --- a/kuscia/source/filters/http/kuscia_crypt/crypt_filter.h +++ b/kuscia/source/filters/http/kuscia_crypt/crypt_filter.h @@ -1,29 +1,26 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once -#include -#include - +#include "kuscia/api/filters/http/kuscia_crypt/v3/crypt.pb.h" +#include "kuscia/source/filters/http/kuscia_crypt/crypter.h" #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" #include "source/extensions/filters/http/common/pass_through_filter.h" - -#include "kuscia/api/filters/http/kuscia_crypt/v3/crypt.pb.h" -#include "kuscia/source/filters/http/kuscia_crypt/crypter.h" +#include +#include namespace Envoy { namespace Extensions { @@ -37,57 +34,55 @@ using CryptConfigSharedPtr = std::shared_ptr; using CryptPbConfig = envoy::extensions::filters::http::kuscia_crypt::v3::Crypt; class CryptFilter : public Envoy::Http::PassThroughFilter, - public Logger::Loggable { - public: - explicit CryptFilter(CryptConfigSharedPtr config) : - config_(config), - request_id_(), - peer_origin_source_(), - enable_resp_encrypt_(false), + public Logger::Loggable { +public: + explicit CryptFilter(CryptConfigSharedPtr config) + : config_(config), request_id_(), peer_origin_source_(), enable_resp_encrypt_(false), enable_resp_decrypt_(false) {} - Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, - bool end_stream) override; - Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, + bool end_stream) override; + Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; - Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, - bool end_stream) override; - Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; + Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, + bool end_stream) override; + Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; - private: - void createForwardCrypter(Http::RequestHeaderMap& headers, bool end_stream); - void createReverseCrypter(Http::RequestHeaderMap& headers, bool end_stream); - void checkRespEncrypt(Http::ResponseHeaderMap& headers, bool end_stream); - void checkRespDecrypt(Http::ResponseHeaderMap& headers, bool end_stream); +private: + void createForwardCrypter(Http::RequestHeaderMap& headers, bool end_stream); + void createReverseCrypter(Http::RequestHeaderMap& headers, bool end_stream); + void checkRespEncrypt(Http::ResponseHeaderMap& headers, bool end_stream); + void checkRespDecrypt(Http::ResponseHeaderMap& headers, bool end_stream); + void sendCryptErrorReply(); - CryptConfigSharedPtr config_; - std::string request_id_; - std::string peer_origin_source_; + CryptConfigSharedPtr config_; + std::string request_id_; + std::string peer_origin_source_; - bool enable_resp_encrypt_; - bool enable_resp_decrypt_; - KusciaCrypterSharedPtr forward_crypter_; - KusciaCrypterSharedPtr reverse_crypter_; + bool enable_resp_encrypt_; + bool enable_resp_decrypt_; + KusciaCrypterSharedPtr forward_crypter_; + KusciaCrypterSharedPtr reverse_crypter_; - Buffer::OwnedImpl left_data_; + Buffer::OwnedImpl left_data_; - friend class CryptFilterTest; + friend class CryptFilterTest; }; class KusciaCryptConfig { - public: - explicit KusciaCryptConfig(const CryptPbConfig& config); - - const std::string& selfNamespace() const; - bool forwardEncryption() const; - bool reverseEncryption() const; - const CryptRule* getEncryptRule(const std::string& source, const std::string& dest) const; - const CryptRule* getDecryptRule(const std::string& source, const std::string& dest) const; - - private: - std::string namespace_; - std::map, CryptRule> encrypt_rules_; - std::map, CryptRule> decrypt_rules_; +public: + explicit KusciaCryptConfig(const CryptPbConfig& config); + + const std::string& selfNamespace() const; + bool forwardEncryption() const; + bool reverseEncryption() const; + const CryptRule* getEncryptRule(const std::string& source, const std::string& dest) const; + const CryptRule* getDecryptRule(const std::string& source, const std::string& dest) const; + +private: + std::string namespace_; + std::map, CryptRule> encrypt_rules_; + std::map, CryptRule> decrypt_rules_; }; } // namespace KusciaCrypt diff --git a/kuscia/source/filters/http/kuscia_crypt/crypter.cc b/kuscia/source/filters/http/kuscia_crypt/crypter.cc index d9699b1..662594b 100755 --- a/kuscia/source/filters/http/kuscia_crypt/crypter.cc +++ b/kuscia/source/filters/http/kuscia_crypt/crypter.cc @@ -1,32 +1,27 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_crypt/crypter.h" - -#include - #include "fmt/format.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "openssl/aes.h" #include "openssl/crypto.h" #include "openssl/evp.h" #include "openssl/rand.h" - -#include "source/common/common/base64.h" #include "source/common/common/assert.h" - -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "source/common/common/base64.h" +#include namespace Envoy { namespace Extensions { @@ -36,158 +31,231 @@ namespace KusciaCrypt { static const std::string AlgorithmAES("AES"); static constexpr uint32_t AESEncryptBlockSize = 1024; - class AESCrypter : public KusciaCrypter { - public: - AESCrypter(const std::string& secret_key, const std::string& version, const std::string& iv) : - KusciaCrypter(secret_key, version), - iv_(iv) {} - - AESCrypter(std::string&& secret_key, std::string&& version, std::string&& iv): - KusciaCrypter(std::move(secret_key), std::move(version)), - iv_(std::move(iv)) {} - - virtual ~AESCrypter() {} - - static KusciaCrypterSharedPtr createForwardAESCrypter(std::string&& key, std::string&& version, - Http::RequestHeaderMap& headers); - - static KusciaCrypterSharedPtr createReverseAESCrypter(std::string&& key, std::string&& version, - Http::RequestHeaderMap& headers); - - bool encrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) override; - bool decrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) override; - - bool checkRespEncrypt(Http::ResponseHeaderMap& headers) override; - bool checkRespDecrypt(Http::ResponseHeaderMap& headers) override; - - private: - bool doCrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data); - - const std::string iv_; -}; - -KusciaCrypterSharedPtr KusciaCrypter::createForwardCrypter(const CryptRule& rule, - Http::RequestHeaderMap& headers) { - auto encrypt_version = - headers.getByKey(KusciaCommon::HeaderKeyEncryptVersion).value_or(std::string()); - if (!encrypt_version.empty()) { - return KusciaCrypterSharedPtr(); +public: + AESCrypter(std::string&& secret_key, std::string&& version, std::string&& iv) + : KusciaCrypter(std::move(secret_key), std::move(version)), ctx_enc_(EVP_CIPHER_CTX_new()), + ctx_dec_(EVP_CIPHER_CTX_new()), iv_(std::move(iv)), enc_init_(false), dec_init_(false) {} + + ~AESCrypter() { + if (ctx_enc_) { + EVP_CIPHER_CTX_free(ctx_enc_); } - headers.addCopy(KusciaCommon::HeaderKeyEncryptVersion, rule.secret_key_version()); - - if (rule.algorithm() == AlgorithmAES) { - return AESCrypter::createForwardAESCrypter(std::string(rule.secret_key()), - std::string(rule.secret_key_version()), - headers); + if (ctx_dec_) { + EVP_CIPHER_CTX_free(ctx_dec_); } - return KusciaCrypterSharedPtr(); -} + } -KusciaCrypterSharedPtr KusciaCrypter::createReverseCrypter(const CryptRule& rule, - Http::RequestHeaderMap& headers) { - std::string encrypt_version = - std::string(headers.getByKey(KusciaCommon::HeaderKeyEncryptVersion).value_or("")); - if (encrypt_version.empty()) { - return KusciaCrypterSharedPtr(); - } - - std::string secret_key; - if (encrypt_version == rule.secret_key_version()) { - secret_key = rule.secret_key(); - } else if (encrypt_version == rule.reserve_key_version()) { - secret_key = rule.reserve_key(); - } else { - ENVOY_LOG(warn, "unknown secret key version {}", encrypt_version); - return KusciaCrypterSharedPtr(); - } - - if (rule.algorithm() == AlgorithmAES) { - return AESCrypter::createReverseAESCrypter(std::move(secret_key), std::move(encrypt_version), headers); - } - return KusciaCrypterSharedPtr(); -} - -KusciaCrypterSharedPtr AESCrypter::createForwardAESCrypter(std::string&& key, std::string&& version, - Http::RequestHeaderMap& headers) { - std::string iv(AES_BLOCK_SIZE, '\0'); + static KusciaCrypterSharedPtr createForwardAESCrypter(std::string&& key, std::string&& version, + Http::RequestHeaderMap& headers) { + std::string iv(AES_GCM_IV_LENGTH, '\0'); int rc = RAND_bytes(reinterpret_cast(iv.data()), iv.length()); ASSERT(rc); std::string encoded_iv = Base64::encode(iv.data(), iv.length()); headers.addCopy(KusciaCommon::HeaderKeyEncryptIv, encoded_iv); return std::make_shared(std::move(key), std::move(version), std::move(iv)); -} + } -KusciaCrypterSharedPtr AESCrypter::createReverseAESCrypter(std::string&& key, std::string&& version, - Http::RequestHeaderMap& headers) { - auto encoded_iv = headers.getByKey(KusciaCommon::HeaderKeyEncryptIv); - if (!encoded_iv.has_value()) { - return KusciaCrypterSharedPtr(); + static KusciaCrypterSharedPtr createReverseAESCrypter(std::string&& key, std::string&& version, + Http::RequestHeaderMap& headers) { + auto encoded_iv = headers.get(KusciaCommon::HeaderKeyEncryptIv); + if (encoded_iv.empty()) { + return KusciaCrypterSharedPtr(); } - std::string iv = Base64::decode(encoded_iv.value()); + std::string iv = Base64::decode(encoded_iv[0]->value().getStringView()); return std::make_shared(std::move(key), std::move(version), std::move(iv)); -} + } -bool AESCrypter::encrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) { - return doCrypt(data, end_stream, left_data); -} + bool checkRespEncrypt(Http::ResponseHeaderMap& headers) override { + headers.addCopy(KusciaCommon::HeaderKeyEncryptVersion, version_); + return true; + } -bool AESCrypter::decrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) { - return doCrypt(data, end_stream, left_data); -} + bool checkRespDecrypt(Http::ResponseHeaderMap& headers) override { + auto result = headers.get(KusciaCommon::HeaderKeyEncryptVersion); + if (result.size() != 1 || result[0] == nullptr || result[0]->value().empty()) { + return false; + } + return true; + } -bool AESCrypter::doCrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) { - AES_KEY aes_key; - if (AES_set_encrypt_key(reinterpret_cast(secret_key_.data()), - 128, &aes_key) != 0) { - return false; + bool encrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) override { + if (data.length() > 0) { + left_data.move(data); } - if (left_data.length() > 0) { - data.prepend(left_data); // left_data is automatically drained after prepend + // Wait for more data + if (!end_stream && left_data.length() < AESEncryptBlockSize) { + return true; } - auto remain_length = data.length(); - auto output = std::make_unique(AESEncryptBlockSize); - auto crypt = [&](uint32_t len) { - // iv is changed after encrypt, so initialize it every time - unsigned char iv[AES_BLOCK_SIZE] {}; - std::memcpy(iv, iv_.data(), std::min(AES_BLOCK_SIZE, int(iv_.length()))); - int num = 0; - - // encrypts and decrypts are the same with OFB mode - AES_ofb128_encrypt(reinterpret_cast(data.linearize(len)), - output.get(), len, &aes_key, iv, &num); - data.drain(len); - data.add(output.get(), len); - return len; - }; - - while (remain_length >= AESEncryptBlockSize) { - remain_length -= crypt(AESEncryptBlockSize); + int out_len = 0; + size_t data_len = left_data.length(); + std::vector buffer(data_len + EVP_CIPHER_block_size(EVP_aes_128_gcm())); + // Initialise key and IV + if (!enc_init_) { + enc_init_ = true; + if (1 != EVP_EncryptInit_ex(ctx_enc_, EVP_aes_128_gcm(), nullptr, + reinterpret_cast(secret_key_.c_str()), + reinterpret_cast(iv_.c_str()))) { + ENVOY_LOG(warn, "Failed to init encrypt context."); + return false; + } + } + /* + * Provide the message to be encrypted, and obtain the encrypted output. + * EVP_EncryptUpdate can be called multiple times if necessary + */ + if (data_len > 0) { + if (1 != + EVP_EncryptUpdate(ctx_enc_, buffer.data(), &out_len, + reinterpret_cast(left_data.linearize(data_len)), + data_len)) { + ENVOY_LOG(warn, "Failed to encrypt data."); + return false; + } } - if (remain_length > 0) { - if (end_stream) { - crypt(remain_length); - } else { - left_data.add(data.linearize(remain_length), remain_length); - data.drain(remain_length); - } + data.add(buffer.data(), out_len); + left_data.drain(data_len); + + if (end_stream) { + /* + * Finalise the encryption. Normally ciphertext bytes may be written at + * this stage, but this does not occur in GCM mode + */ + if (1 != EVP_EncryptFinal_ex(ctx_enc_, buffer.data() + out_len, &out_len)) { + ENVOY_LOG(warn, "Failed to finalize encrypt."); + return false; + } + // Get the tag + std::vector tag(AES_GCM_TAG_LENGTH); + if (1 != + EVP_CIPHER_CTX_ctrl(ctx_enc_, EVP_CTRL_GCM_GET_TAG, AES_GCM_TAG_LENGTH, tag.data())) { + ENVOY_LOG(warn, "Failed to get tag."); + return false; + } + data.add(tag.data(), AES_GCM_TAG_LENGTH); } - return true; -} -bool AESCrypter::checkRespEncrypt(Http::ResponseHeaderMap& headers) { - headers.addCopy(KusciaCommon::HeaderKeyEncryptVersion, version_); return true; -} + } -bool AESCrypter::checkRespDecrypt(Http::ResponseHeaderMap& headers) { - auto result = headers.get(KusciaCommon::HeaderKeyEncryptVersion); - if (result.size() != 1 || result[0] == nullptr || result[0]->value().empty()) { + bool decrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) override { + if (data.length() > 0) { + left_data.move(data); + } + /* + * 1. If end_stream and length < AES_GCM_TAG_LENGTH, return false, which means corrupted data + * 2. If not end_stream and length < AESEncryptBlockSize, return true, which means wait for + * more data + */ + if (end_stream) { + if (left_data.length() < AES_GCM_TAG_LENGTH) { + ENVOY_LOG(warn, "Data corrupted, length < tag."); return false; + } + } else if (left_data.length() < AESEncryptBlockSize) { + return true; + } + int out_len = 0; + size_t data_len = left_data.length() - AES_GCM_TAG_LENGTH; + std::vector buffer(data_len); + // Initialise key and IV + if (!dec_init_) { + dec_init_ = true; + if (1 != EVP_DecryptInit_ex(ctx_dec_, EVP_aes_128_gcm(), nullptr, + reinterpret_cast(secret_key_.c_str()), + reinterpret_cast(iv_.c_str()))) { + ENVOY_LOG(warn, "Failed to init decrypt context."); + return false; + } + } + /* + * Provide the message to be decrypted, and obtain the plaintext output. + * EVP_DecryptUpdate can be called multiple times if necessary + */ + if (data_len > 0) { + if (1 != + EVP_DecryptUpdate(ctx_dec_, buffer.data(), &out_len, + reinterpret_cast(left_data.linearize(data_len)), + data_len)) { + ENVOY_LOG(warn, "Failed to dcrypt data."); + return false; + } + data.add(buffer.data(), out_len); + } + + left_data.drain(data_len); + + if (end_stream) { + // Set expected tag value. Works in OpenSSL 1.0.1d and later + std::vector tag(AES_GCM_TAG_LENGTH); + left_data.copyOut(0, AES_GCM_TAG_LENGTH, tag.data()); + left_data.drain(AES_GCM_TAG_LENGTH); + if (1 != + EVP_CIPHER_CTX_ctrl(ctx_dec_, EVP_CTRL_GCM_SET_TAG, AES_GCM_TAG_LENGTH, tag.data())) { + ENVOY_LOG(warn, "Failed to set tag."); + return false; + } + /* + * Finalise the decryption. A positive return value indicates success, + * anything else is a failure - the plaintext is not trustworthy. + */ + if (1 != EVP_DecryptFinal_ex(ctx_dec_, buffer.data() + out_len, &out_len)) { + ENVOY_LOG(warn, "Failed to finalize decrypt."); + return false; + } } return true; + } + +private: + EVP_CIPHER_CTX* ctx_enc_; + EVP_CIPHER_CTX* ctx_dec_; + std::string iv_; + bool enc_init_; + bool dec_init_; +}; + +KusciaCrypterSharedPtr KusciaCrypter::createForwardCrypter(const CryptRule& rule, + Http::RequestHeaderMap& headers) { + auto encrypt_version = headers.get(KusciaCommon::HeaderKeyEncryptVersion); + if (!encrypt_version.empty()) { + return KusciaCrypterSharedPtr(); + } + headers.addCopy(KusciaCommon::HeaderKeyEncryptVersion, rule.secret_key_version()); + + if (rule.algorithm() == AlgorithmAES) { + return AESCrypter::createForwardAESCrypter(std::string(rule.secret_key()), + std::string(rule.secret_key_version()), headers); + } + return KusciaCrypterSharedPtr(); +} + +KusciaCrypterSharedPtr KusciaCrypter::createReverseCrypter(const CryptRule& rule, + Http::RequestHeaderMap& headers) { + std::string encrypt_version; + auto value = headers.get(KusciaCommon::HeaderKeyEncryptVersion); + if (value.empty()) { + return KusciaCrypterSharedPtr(); + } + encrypt_version = std::string(value[0]->value().getStringView()); + + std::string secret_key; + if (encrypt_version == rule.secret_key_version()) { + secret_key = rule.secret_key(); + } else if (encrypt_version == rule.reserve_key_version()) { + secret_key = rule.reserve_key(); + } else { + ENVOY_LOG(warn, "unknown secret key version {}", encrypt_version); + return KusciaCrypterSharedPtr(); + } + + if (rule.algorithm() == AlgorithmAES) { + return AESCrypter::createReverseAESCrypter(std::move(secret_key), std::move(encrypt_version), + headers); + } + return KusciaCrypterSharedPtr(); } } // namespace KusciaCrypt diff --git a/kuscia/source/filters/http/kuscia_crypt/crypter.h b/kuscia/source/filters/http/kuscia_crypt/crypter.h index 3551699..89e9da9 100755 --- a/kuscia/source/filters/http/kuscia_crypt/crypter.h +++ b/kuscia/source/filters/http/kuscia_crypt/crypter.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include "source/common/buffer/buffer_impl.h" @@ -32,38 +31,37 @@ class KusciaCrypter; using KusciaCrypterSharedPtr = std::shared_ptr; using CryptRule = envoy::extensions::filters::http::kuscia_crypt::v3::CryptRule; +static const int AES_GCM_TAG_LENGTH = 16; // GCM recommend tag len +static const int AES_GCM_IV_LENGTH = 12; // GCM recommend iv len class KusciaCrypter : public Logger::Loggable { - public: - static KusciaCrypterSharedPtr createForwardCrypter(const CryptRule& rule, - Http::RequestHeaderMap& headers); - static KusciaCrypterSharedPtr createReverseCrypter(const CryptRule& rule, - Http::RequestHeaderMap& headers); +public: + static KusciaCrypterSharedPtr createForwardCrypter(const CryptRule& rule, + Http::RequestHeaderMap& headers); + static KusciaCrypterSharedPtr createReverseCrypter(const CryptRule& rule, + Http::RequestHeaderMap& headers); - KusciaCrypter(const std::string& secret_key, const std::string& version) : - secret_key_(secret_key), - version_(version) {} + KusciaCrypter(const std::string& secret_key, const std::string& version) + : secret_key_(secret_key), version_(version) {} - KusciaCrypter(std::string&& secret_key, std::string&& version) : - secret_key_(std::move(secret_key)), - version_(std::move(version)) {} + KusciaCrypter(std::string&& secret_key, std::string&& version) + : secret_key_(std::move(secret_key)), version_(std::move(version)) {} - virtual ~KusciaCrypter() {} - virtual bool encrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) = 0; - virtual bool decrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) = 0; + virtual ~KusciaCrypter() {} + virtual bool encrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) = 0; + virtual bool decrypt(Buffer::Instance& data, bool end_stream, Buffer::Instance& left_data) = 0; - virtual bool checkRespEncrypt(Http::ResponseHeaderMap& headers) = 0; - virtual bool checkRespDecrypt(Http::ResponseHeaderMap& headers) = 0; + virtual bool checkRespEncrypt(Http::ResponseHeaderMap& headers) = 0; + virtual bool checkRespDecrypt(Http::ResponseHeaderMap& headers) = 0; - protected: - const std::string secret_key_; - const std::string version_; +protected: + const std::string secret_key_; + const std::string version_; - friend class CryptFilterTest; + friend class CryptFilterTest; }; } // namespace KusciaCrypt } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/source/filters/http/kuscia_gress/BUILD b/kuscia/source/filters/http/kuscia_gress/BUILD index 80b97fb..0cc8360 100755 --- a/kuscia/source/filters/http/kuscia_gress/BUILD +++ b/kuscia/source/filters/http/kuscia_gress/BUILD @@ -18,7 +18,7 @@ envoy_cc_library( "@envoy//source/extensions/filters/http/common:pass_through_filter_lib", "@envoy//source/common/http:codes_lib", "@com_github_nlohmann_json//:json", - "@envoy//source/common/api:os_sys_calls_lib", + "@envoy//source/common/api:os_sys_calls_lib" ], ) diff --git a/kuscia/source/filters/http/kuscia_gress/config.cc b/kuscia/source/filters/http/kuscia_gress/config.cc index a5c140b..c1aefce 100755 --- a/kuscia/source/filters/http/kuscia_gress/config.cc +++ b/kuscia/source/filters/http/kuscia_gress/config.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_gress/config.h" #include "envoy/registry/registry.h" @@ -26,16 +25,14 @@ namespace KusciaGress { Http::FilterFactoryCb GressConfigFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::http::kuscia_gress::v3::Gress& proto_config, - const std::string&, - Server::Configuration::FactoryContext&) { - GressFilterConfigSharedPtr config = std::make_shared(proto_config); - return [config](Http::FilterChainFactoryCallbacks & callbacks) -> void { - callbacks.addStreamFilter(std::make_shared(config)); - }; + const std::string&, Server::Configuration::FactoryContext&) { + GressFilterConfigSharedPtr config = std::make_shared(proto_config); + return [config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamFilter(std::make_shared(config)); + }; } -REGISTER_FACTORY(GressConfigFactory, - Server::Configuration::NamedHttpFilterConfigFactory); +REGISTER_FACTORY(GressConfigFactory, Server::Configuration::NamedHttpFilterConfigFactory); } // namespace KusciaGress } // namespace HttpFilters diff --git a/kuscia/source/filters/http/kuscia_gress/config.h b/kuscia/source/filters/http/kuscia_gress/config.h index ddcf317..1c61973 100755 --- a/kuscia/source/filters/http/kuscia_gress/config.h +++ b/kuscia/source/filters/http/kuscia_gress/config.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include "source/extensions/filters/http/common/factory_base.h" @@ -25,15 +24,14 @@ namespace Extensions { namespace HttpFilters { namespace KusciaGress { -class GressConfigFactory : public Extensions::HttpFilters::Common::FactoryBase < - envoy::extensions::filters::http::kuscia_gress::v3::Gress > { - public: - GressConfigFactory() : FactoryBase("envoy.filters.http.kuscia_gress") {} +class GressConfigFactory : public Extensions::HttpFilters::Common::FactoryBase< + envoy::extensions::filters::http::kuscia_gress::v3::Gress> { +public: + GressConfigFactory() : FactoryBase("envoy.filters.http.kuscia_gress") {} - Http::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_gress::v3::Gress&, - const std::string&, - Server::Configuration::FactoryContext&) override; + Http::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::http::kuscia_gress::v3::Gress&, const std::string&, + Server::Configuration::FactoryContext&) override; }; } // namespace KusciaGress diff --git a/kuscia/source/filters/http/kuscia_gress/gress_filter.cc b/kuscia/source/filters/http/kuscia_gress/gress_filter.cc index cad96db..5c6dc5f 100755 --- a/kuscia/source/filters/http/kuscia_gress/gress_filter.cc +++ b/kuscia/source/filters/http/kuscia_gress/gress_filter.cc @@ -12,239 +12,277 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_gress/gress_filter.h" - #include "fmt/format.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "source/common/http/codes.h" #include "source/common/http/header_utility.h" #include "source/common/http/headers.h" - -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include +#include +#include +#include namespace Envoy { namespace Extensions { namespace HttpFilters { namespace KusciaGress { -static void adjustContentLength(Http::RequestOrResponseHeaderMap& headers, uint64_t delta_length) { - auto length_header = headers.getContentLengthValue(); - if (!length_header.empty()) { - uint64_t old_length; - if (absl::SimpleAtoi(length_header, &old_length)) { - if (old_length != 0) { - headers.setContentLength(old_length + delta_length); - } - } +static std::string replaceNamespaceInHost(absl::string_view host, + absl::string_view new_namespace) { + std::vector fields = absl::StrSplit(host, "."); + for (std::size_t i = 2; i < fields.size(); i++) { + if (fields[i] == "svc") { + fields[i - 1] = new_namespace; + return absl::StrJoin(fields, "."); } + } + return ""; } -static std::string replaceNamespaceInHost(absl::string_view host, absl::string_view new_namespace) { - std::vector fields = absl::StrSplit(host, "."); - for (std::size_t i = 2; i < fields.size(); i++) { - if (fields[i] == "svc") { - fields[i - 1] = new_namespace; - return absl::StrJoin(fields, "."); - } - } - return ""; +static std::string getGatewayDesc(const std::string& domain, const std::string& instance, + const std::string& listener) { + return fmt::format("{}/{}/{}", domain, instance, listener); } -RewriteHostConfig::RewriteHostConfig(const RewriteHost& config) : - rewrite_policy_(config.rewrite_policy()), - header_(config.header()), - specified_host_(config.specified_host()) { - path_matchers_.reserve(config.path_matchers_size()); - for (const auto& pm : config.path_matchers()) { - PathMatcherConstSharedPtr matcher(new Envoy::Matchers::PathMatcher(pm)); - path_matchers_.emplace_back(matcher); - } +static std::string getListener(const StreamInfo::StreamInfo& stream_info) { + std::string address; + auto& provider = stream_info.downstreamAddressProvider(); + if (provider.localAddress() != nullptr) { + address = provider.localAddress()->asString(); + } + if (address.empty()) { + return "-"; + } + return absl::EndsWith(address, ":80") ? "internal" : "external"; } -GressFilterConfig::GressFilterConfig(const GressPbConfig& config) : - instance_(config.instance()), - self_namespace_(config.self_namespace()), - add_origin_source_(config.add_origin_source()), - max_logging_body_size_per_reqeuest_(config.max_logging_body_size_per_reqeuest()) { - rewrite_host_config_.reserve(config.rewrite_host_config_size()); - for (const auto& rh : config.rewrite_host_config()) { - rewrite_host_config_.emplace_back(RewriteHostConfig(rh)); - } +static std::string getCause(const StreamInfo::StreamInfo& stream_info) { + std::string cause; + if (stream_info.responseCodeDetails().has_value()) { + cause = stream_info.responseCodeDetails().value(); + } + return cause; } -Http::FilterHeadersStatus GressFilter::decodeHeaders(Http::RequestHeaderMap& headers, - bool) { - // store some useful headers - request_id_ = std::string(headers.getRequestIdValue()); - host_ = std::string(headers.getHostValue()); - auto record = headers.getByKey(KusciaCommon::HeaderKeyRecordBody); - if (record.has_value() && record.value() == "true") { - record_request_body_ = true; - record_response_body_ = true; - } +std::string strip(absl::string_view sv) { return std::string(sv.data(), sv.size()); } - // rewrite host to choose a new route - if (rewriteHost(headers)) { - decoder_callbacks_->downstreamCallbacks()->clearRouteCache(); - } else { - // replace ".svc:" with ".svc" for internal request - size_t n = host_.rfind(".svc:"); - if (n != std::string::npos) { - std::string substr = host_.substr(0, n + 4); - headers.setHost(substr); - decoder_callbacks_->downstreamCallbacks()->clearRouteCache(); - } +std::string getHeaderValue(const Http::ResponseHeaderMap& headers, + const Http::LowerCaseString& key) { + auto value_header = headers.get(key); + if (!value_header.empty() && value_header[0] != nullptr && !value_header[0]->value().empty()) { + return strip(value_header[0]->value().getStringView()); + } + return ""; +} + +RewriteHostConfig::RewriteHostConfig(const RewriteHost& config) + : rewrite_policy_(config.rewrite_policy()), header_(config.header()), + specified_host_(config.specified_host()) { + path_matchers_.reserve(config.path_matchers_size()); + for (const auto& pm : config.path_matchers()) { + PathMatcherConstSharedPtr matcher(new Envoy::Matchers::PathMatcher(pm)); + path_matchers_.emplace_back(matcher); + } +} + +GressFilterConfig::GressFilterConfig(const GressPbConfig& config) + : instance_(config.instance()), self_namespace_(config.self_namespace()), + add_origin_source_(config.add_origin_source()), + max_logging_body_size_per_reqeuest_(config.max_logging_body_size_per_reqeuest()) { + rewrite_host_config_.reserve(config.rewrite_host_config_size()); + for (const auto& rh : config.rewrite_host_config()) { + rewrite_host_config_.emplace_back(RewriteHostConfig(rh)); + } +} + +Http::FilterHeadersStatus GressFilter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { + // store some useful headers + request_id_ = std::string(headers.getRequestIdValue()); + host_ = std::string(headers.getHostValue()); + auto record = headers.get(KusciaCommon::HeaderKeyRecordBody); + if (!record.empty() && std::string(record[0]->value().getStringView()) == "true") { + record_request_body_ = true; + record_response_body_ = true; + } + + // rewrite host to choose a new route + if (rewriteHost(headers)) { + decoder_callbacks_->downstreamCallbacks()->clearRouteCache(); + } else { + // replace ".svc:" with ".svc" for internal request + size_t n = host_.rfind(".svc:"); + if (n != std::string::npos) { + std::string substr = host_.substr(0, n + 4); + headers.setHost(substr); + decoder_callbacks_->downstreamCallbacks()->clearRouteCache(); } + } - // add origin-source if not exist - if (config_->addOriginSource()) { - auto origin_source = headers.getByKey(KusciaCommon::HeaderKeyOriginSource) - .value_or(std::string()); - if (origin_source.empty()) { - headers.addCopy(KusciaCommon::HeaderKeyOriginSource, config_->selfNamespace()); - } + // add origin-source if not exist + if (config_->addOriginSource()) { + auto origin_source = headers.get(KusciaCommon::HeaderKeyOriginSource); + if (origin_source.empty()) { + headers.addCopy(KusciaCommon::HeaderKeyOriginSource, config_->selfNamespace()); } + } - return Http::FilterHeadersStatus::Continue; + return Http::FilterHeadersStatus::Continue; } Http::FilterDataStatus GressFilter::decodeData(Buffer::Instance& data, bool end_stream) { - if (record_request_body_) { - record_request_body_ = recordBody(req_body_, data, end_stream, true); + if (record_request_body_) { + record_request_body_ = recordBody(req_body_, data, end_stream, true); + } + return Http::FilterDataStatus::Continue; +} + +Http::FilterHeadersStatus GressFilter::encodeHeaders(Http::ResponseHeaderMap& headers, bool) { + uint64_t status_code = 0; + if (absl::SimpleAtoi(headers.getStatusValue(), &status_code)) { + if (!(status_code >= 400 && status_code < 600)) { + return Http::FilterHeadersStatus::Continue; } - return Http::FilterDataStatus::Continue; -} - -Http::FilterHeadersStatus GressFilter::encodeHeaders(Http::ResponseHeaderMap& headers, - bool end_stream) { - // generate error msg - auto result = headers.get(KusciaCommon::HeaderKeyErrorMessage); - if (headers.getStatusValue() != "200") { - std::string err_msg; - auto result = headers.get(KusciaCommon::HeaderKeyErrorMessage); - if (result.empty()) { - auto inner_msg = headers.get(KusciaCommon::HeaderKeyErrorMessageInternal); - if (inner_msg.size() == 1 && inner_msg[0] != nullptr && !inner_msg[0]->value().empty()) { - err_msg = fmt::format("Domain {}.{}: {}", - config_->selfNamespace(), - config_->instance(), - inner_msg[0]->value().getStringView()); - headers.remove(KusciaCommon::HeaderKeyErrorMessageInternal); - } else { - err_msg = fmt::format("Domain {}.{}<--{} return http code {}.", - config_->selfNamespace(), - config_->instance(), - host_, - headers.getStatusValue()); - } - } else if (result[0] != nullptr) { - err_msg = fmt::format("Domain {}.{}<--{}", - config_->selfNamespace(), - config_->instance(), - result[0]->value().getStringView()); - - } - - headers.setCopy(KusciaCommon::HeaderKeyErrorMessage, err_msg); - if (end_stream) { - Envoy::Buffer::OwnedImpl body(err_msg); - adjustContentLength(headers, body.length()); - encoder_callbacks_->addEncodedData(body, true); - headers.setReferenceContentType(Http::Headers::get().ContentTypeValues.Text); - } + } + // 1. if error message key is set in response, then use it as error message + // 2. if internal error message key is set in response, then use it as error message + // 3. if neither of above, then use default error message + std::string error_message = getHeaderValue(headers, KusciaCommon::HeaderKeyErrorMessage); + bool formatted = false; + if (!error_message.empty()) { + formatted = true; + } else { + error_message = getHeaderValue(headers, KusciaCommon::HeaderKeyErrorMessageInternal); + if (error_message.empty()) { + error_message = Http::CodeUtility::toString(static_cast(status_code)); + } else { + headers.remove(KusciaCommon::HeaderKeyErrorMessageInternal); } - return Http::FilterHeadersStatus::Continue; + } + auto& stream_info = encoder_callbacks_->streamInfo(); + std::string rich_message = getRichMessage(stream_info, error_message, formatted); + headers.setCopy(KusciaCommon::HeaderKeyErrorMessage, rich_message); + return Http::FilterHeadersStatus::Continue; } Http::FilterDataStatus GressFilter::encodeData(Buffer::Instance& data, bool end_stream) { - if (record_response_body_) { - record_response_body_ = recordBody(resp_body_, data, end_stream, false); - } - return Http::FilterDataStatus::Continue; + if (record_response_body_) { + record_response_body_ = recordBody(resp_body_, data, end_stream, false); + } + return Http::FilterDataStatus::Continue; +} + +// The presence of trailers means the stream is ended, but encodeData() +// is never called with end_stream=true. +Http::FilterTrailersStatus GressFilter::encodeTrailers(Http::ResponseTrailerMap&) { + if (record_response_body_) { + Buffer::OwnedImpl data; + record_response_body_ = recordBody(resp_body_, data, true, false); + } + return Http::FilterTrailersStatus::Continue; +} + +std::string GressFilter::getRichMessage(const StreamInfo::StreamInfo& stream_info, + const std::string& error_message, bool formatted) { + std::string listener = getListener(stream_info); + std::string gateway_desc = + getGatewayDesc(config_->selfNamespace(), config_->instance(), listener); + std::string cause = getCause(stream_info); + std::string rich_message; + if (formatted) { + rich_message = fmt::format("<{}> => {}", gateway_desc, error_message); + } else if (cause == "via_upstream") { + rich_message = fmt::format("<{}> => ", gateway_desc, error_message); + } else { + rich_message = fmt::format("<{} ${}$ {}>", gateway_desc, cause, error_message); + } + return rich_message; } bool GressFilter::rewriteHost(Http::RequestHeaderMap& headers) { - for (const auto& rh : config_->rewriteHostConfig()) { - if (rewriteHost(headers, rh)) { - return true; - } + for (const auto& rh : config_->rewriteHostConfig()) { + if (rewriteHost(headers, rh)) { + return true; } - return false; + } + return false; } bool GressFilter::rewriteHost(Http::RequestHeaderMap& headers, const RewriteHostConfig& rh) { - auto header_value = headers.getByKey(Http::LowerCaseString(rh.header())).value_or(""); - if (header_value.empty()) { - return false; - } - - if (rh.pathMatchers().size() > 0) { - const absl::string_view path = headers.getPathValue(); - bool path_match = false; - for (const auto& pm : rh.pathMatchers()) { - if (pm->match(path)) { - path_match = true; - break; - } - } - if (!path_match) { - return false; - } - } + auto value = headers.get(Http::LowerCaseString(rh.header())); + if (value.empty()) { + return false; + } + absl::string_view header_value = value[0]->value().getStringView(); - switch (rh.rewritePolicy()) { - case RewriteHost::RewriteHostWithHeader: { - headers.setHost(header_value); - return true; - } - case RewriteHost::RewriteNamespaceWithHeader: { - auto host_value = replaceNamespaceInHost(headers.getHostValue(), header_value); - if (!host_value.empty()) { - headers.setHost(host_value); - return true; - } + if (rh.pathMatchers().size() > 0) { + const absl::string_view path = headers.getPathValue(); + bool path_match = false; + for (const auto& pm : rh.pathMatchers()) { + if (pm->match(path)) { + path_match = true; break; + } } - case RewriteHost::RewriteHostWithSpecifiedHost: { - if (!rh.specifiedHost().empty()) { - headers.setHost(rh.specifiedHost()); - return true; - } - break; + if (!path_match) { + return false; } - default: - break; + } + + switch (rh.rewritePolicy()) { + case RewriteHost::RewriteHostWithHeader: { + headers.setHost(header_value); + return true; + } + case RewriteHost::RewriteNamespaceWithHeader: { + auto host_value = replaceNamespaceInHost(headers.getHostValue(), header_value); + if (!host_value.empty()) { + headers.setHost(host_value); + return true; + } + break; + } + case RewriteHost::RewriteHostWithSpecifiedHost: { + if (!rh.specifiedHost().empty()) { + headers.setHost(rh.specifiedHost()); + return true; } + break; + } + default: + break; + } - return false; + return false; } -bool GressFilter::recordBody(Buffer::OwnedImpl& body, Buffer::Instance& data, - bool end_stream, bool is_req) { - auto& stream_info = is_req ? decoder_callbacks_->streamInfo() : encoder_callbacks_->streamInfo(); - std::string body_key = is_req ? "request_body" : "response_body"; - - uint64_t logging_size = static_cast(config_->maxLoggingBodySizePerReqeuest()); - bool record_body = true; - if (data.length() > 0) { - if (logging_size > 0 && body.length() + data.length() > logging_size) { - ENVOY_LOG(info, "{} of {} already larger than {}, stop logging", - body_key, request_id_, logging_size); - record_body = false; - Buffer::OwnedImpl empty_buffer{}; - empty_buffer.move(body); - } else { - body.add(data); - } - } +bool GressFilter::recordBody(Buffer::OwnedImpl& body, Buffer::Instance& data, bool end_stream, + bool is_req) { + auto& stream_info = is_req ? decoder_callbacks_->streamInfo() : encoder_callbacks_->streamInfo(); + std::string body_key = is_req ? "request_body" : "response_body"; - if (end_stream && body.length() > 0) { - ProtobufWkt::Value value; - value.set_string_value(body.toString()); - ProtobufWkt::Struct metadata; - (*metadata.mutable_fields())[body_key] = value; - stream_info.setDynamicMetadata("envoy.kuscia", metadata); + uint64_t logging_size = static_cast(config_->maxLoggingBodySizePerReqeuest()); + bool record_body = true; + if (data.length() > 0) { + if (logging_size > 0 && body.length() + data.length() > logging_size) { + ENVOY_LOG(info, "{} of {} already larger than {}, stop logging", body_key, request_id_, + logging_size); + record_body = false; + Buffer::OwnedImpl empty_buffer{}; + empty_buffer.move(body); + } else { + body.add(data); } - return record_body; + } + + if (end_stream && body.length() > 0) { + ProtobufWkt::Value value; + value.set_string_value(body.toString()); + ProtobufWkt::Struct metadata; + (*metadata.mutable_fields())[body_key] = value; + stream_info.setDynamicMetadata("envoy.kuscia", metadata); + } + return record_body; } } // namespace KusciaGress diff --git a/kuscia/source/filters/http/kuscia_gress/gress_filter.h b/kuscia/source/filters/http/kuscia_gress/gress_filter.h index f23c7fc..c6e2355 100755 --- a/kuscia/source/filters/http/kuscia_gress/gress_filter.h +++ b/kuscia/source/filters/http/kuscia_gress/gress_filter.h @@ -12,20 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma once -#include -#include - +#include "envoy/common/matchers.h" +#include "include/nlohmann/json.hpp" +#include "kuscia/api/filters/http/kuscia_gress/v3/gress.pb.h" #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" -#include "source/extensions/filters/http/common/pass_through_filter.h" - -#include "kuscia/api/filters/http/kuscia_gress/v3/gress.pb.h" - -#include "envoy/common/matchers.h" #include "source/common/common/matchers.h" +#include "source/extensions/filters/http/common/pass_through_filter.h" +#include +#include +#include namespace Envoy { namespace Extensions { @@ -38,96 +36,79 @@ using RewritePolicy = RewriteHost::RewritePolicy; using PathMatcherConstSharedPtr = std::shared_ptr; class RewriteHostConfig { - public: - explicit RewriteHostConfig(const RewriteHost& config); - - const std::string& header() const { - return header_; - } - RewritePolicy rewritePolicy() const { - return rewrite_policy_; - } - const std::string& specifiedHost() const { - return specified_host_; - } - - const std::vector& pathMatchers() const { - return path_matchers_; - } - - private: - RewriteHost::RewritePolicy rewrite_policy_; - std::string header_; - std::string specified_host_; - std::vector path_matchers_; +public: + explicit RewriteHostConfig(const RewriteHost& config); + + const std::string& header() const { return header_; } + RewritePolicy rewritePolicy() const { return rewrite_policy_; } + const std::string& specifiedHost() const { return specified_host_; } + + const std::vector& pathMatchers() const { return path_matchers_; } + +private: + RewriteHost::RewritePolicy rewrite_policy_; + std::string header_; + std::string specified_host_; + std::vector path_matchers_; }; class GressFilterConfig { - public: - explicit GressFilterConfig(const GressPbConfig& config); - const std::string& instance() const { - return instance_; - } - - const std::string& selfNamespace() const { - return self_namespace_; - } - - bool addOriginSource() const { - return add_origin_source_; - } - - int32_t maxLoggingBodySizePerReqeuest() { - return max_logging_body_size_per_reqeuest_; - } - - const std::vector& rewriteHostConfig() const { - return rewrite_host_config_; - } - - private: - std::string instance_; - std::string self_namespace_; - bool add_origin_source_; - int32_t max_logging_body_size_per_reqeuest_; - - std::vector rewrite_host_config_; +public: + explicit GressFilterConfig(const GressPbConfig& config); + const std::string& instance() const { return instance_; } + + const std::string& selfNamespace() const { return self_namespace_; } + + bool addOriginSource() const { return add_origin_source_; } + + int32_t maxLoggingBodySizePerReqeuest() { return max_logging_body_size_per_reqeuest_; } + + const std::vector& rewriteHostConfig() const { return rewrite_host_config_; } + +private: + std::string instance_; + std::string self_namespace_; + bool add_origin_source_; + int32_t max_logging_body_size_per_reqeuest_; + + std::vector rewrite_host_config_; }; using GressFilterConfigSharedPtr = std::shared_ptr; - class GressFilter : public Envoy::Http::PassThroughFilter, - public Logger::Loggable { - public: - explicit GressFilter(GressFilterConfigSharedPtr config) : - config_(config), - host_(), - request_id_(), - record_request_body_(false), + public Logger::Loggable { +public: + explicit GressFilter(GressFilterConfigSharedPtr config) + : config_(config), host_(), request_id_(), record_request_body_(false), record_response_body_(false) {} - Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, - bool) override; - Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override; + Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; + + Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, + bool end_stream) override; + Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; + + Http::FilterTrailersStatus encodeTrailers(Http::ResponseTrailerMap& headers) override; + +private: + bool rewriteHost(Http::RequestHeaderMap& headers); + bool rewriteHost(Http::RequestHeaderMap& headers, const RewriteHostConfig& rh); + bool recordBody(Buffer::OwnedImpl& body, Buffer::Instance& data, bool end_stream, bool is_req); - Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, - bool end_stream) override; - Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; + std::string getRichMessage(const StreamInfo::StreamInfo& stream_info, + const std::string& error_message, bool formatted); - private: - bool rewriteHost(Http::RequestHeaderMap& headers); - bool rewriteHost(Http::RequestHeaderMap& headers, const RewriteHostConfig& rh); - bool recordBody(Buffer::OwnedImpl& body, Buffer::Instance& data, bool end_stream, bool is_req); + GressFilterConfigSharedPtr config_; + std::string host_; + std::string request_id_; - GressFilterConfigSharedPtr config_; - std::string host_; - std::string request_id_; + bool record_request_body_; + bool record_response_body_; - bool record_request_body_; - bool record_response_body_; - Buffer::OwnedImpl req_body_; - Buffer::OwnedImpl resp_body_; + Buffer::OwnedImpl req_body_; + Buffer::OwnedImpl resp_body_; }; } // namespace KusciaGress diff --git a/kuscia/source/filters/http/kuscia_header_decorator/config.cc b/kuscia/source/filters/http/kuscia_header_decorator/config.cc index cc2f06d..97d9c9a 100755 --- a/kuscia/source/filters/http/kuscia_header_decorator/config.cc +++ b/kuscia/source/filters/http/kuscia_header_decorator/config.cc @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_header_decorator/config.h" #include "envoy/registry/registry.h" @@ -25,13 +24,13 @@ namespace HttpFilters { namespace KusciaHeaderDecorator { Http::FilterFactoryCb HeaderDecoratorConfigFactory::createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator& proto_config, - const std::string&, - Server::Configuration::FactoryContext&) { + const envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator& + proto_config, + const std::string&, Server::Configuration::FactoryContext&) { - return [proto_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { - callbacks.addStreamDecoderFilter(std::make_shared(proto_config)); - }; + return [proto_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamDecoderFilter(std::make_shared(proto_config)); + }; } REGISTER_FACTORY(HeaderDecoratorConfigFactory, diff --git a/kuscia/source/filters/http/kuscia_header_decorator/config.h b/kuscia/source/filters/http/kuscia_header_decorator/config.h index 26afd0e..2d139d8 100755 --- a/kuscia/source/filters/http/kuscia_header_decorator/config.h +++ b/kuscia/source/filters/http/kuscia_header_decorator/config.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -27,15 +26,15 @@ namespace Extensions { namespace HttpFilters { namespace KusciaHeaderDecorator { -class HeaderDecoratorConfigFactory : public Extensions::HttpFilters::Common::FactoryBase < - envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator > { - public: - HeaderDecoratorConfigFactory() : FactoryBase("envoy.filters.http.kuscia_header_decorator") {} +class HeaderDecoratorConfigFactory + : public Extensions::HttpFilters::Common::FactoryBase< + envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator> { +public: + HeaderDecoratorConfigFactory() : FactoryBase("envoy.filters.http.kuscia_header_decorator") {} - Http::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator&, - const std::string&, - Server::Configuration::FactoryContext&) override; + Http::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator&, + const std::string&, Server::Configuration::FactoryContext&) override; }; } // namespace KusciaHeaderDecorator diff --git a/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.cc b/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.cc index 27dc4ec..f50ba4a 100755 --- a/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.cc +++ b/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h" #include "source/common/common/empty_string.h" @@ -28,34 +27,32 @@ namespace KusciaHeaderDecorator { using KusciaHeader = Envoy::Extensions::HttpFilters::KusciaCommon::KusciaHeader; HeaderDecoratorFilter::HeaderDecoratorFilter(const HeaderDecoratorPbConfig& config) { - for (const auto& source_headers : config.append_headers()) { - std::vector> headers; - headers.reserve(source_headers.headers_size()); - for (const auto& entry : source_headers.headers()) { - headers.emplace_back(entry.key(), entry.value()); - } - append_headers_.emplace(source_headers.source(), headers); + for (const auto& source_headers : config.append_headers()) { + std::vector> headers; + headers.reserve(source_headers.headers_size()); + for (const auto& entry : source_headers.headers()) { + headers.emplace_back(entry.key(), entry.value()); } + append_headers_.emplace(source_headers.source(), headers); + } } Http::FilterHeadersStatus HeaderDecoratorFilter::decodeHeaders(Http::RequestHeaderMap& headers, - bool) { - appendHeaders(headers); - return Http::FilterHeadersStatus::Continue; + bool) { + appendHeaders(headers); + return Http::FilterHeadersStatus::Continue; } void HeaderDecoratorFilter::appendHeaders(Http::RequestHeaderMap& headers) const { - auto source = KusciaHeader::getSource(headers).value_or(""); - auto iter = append_headers_.find(source); - if (iter != append_headers_.end()) { - for (const auto& entry : iter->second) { - headers.addCopy(Http::LowerCaseString(entry.first), entry.second); - } + auto source = KusciaHeader::getSource(headers).value_or(""); + auto iter = append_headers_.find(source); + if (iter != append_headers_.end()) { + for (const auto& entry : iter->second) { + headers.addCopy(Http::LowerCaseString(entry.first), entry.second); } + } } - - } // namespace KusciaHeaderDecorator } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h b/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h index b344a75..4195209 100755 --- a/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h +++ b/kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -29,25 +28,24 @@ namespace HttpFilters { namespace KusciaHeaderDecorator { class HeaderDecoratorConfig; -using HeaderDecoratorPbConfig = envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator; +using HeaderDecoratorPbConfig = + envoy::extensions::filters::http::kuscia_header_decorator::v3::HeaderDecorator; class HeaderDecoratorFilter : public Http::PassThroughDecoderFilter, - public Logger::Loggable { - public: - explicit HeaderDecoratorFilter(const HeaderDecoratorPbConfig& config); - - Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, - bool) override; + public Logger::Loggable { +public: + explicit HeaderDecoratorFilter(const HeaderDecoratorPbConfig& config); - private: - void appendHeaders(Http::RequestHeaderMap& headers) const; + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override; - std::map>, std::less<>> append_headers_; +private: + void appendHeaders(Http::RequestHeaderMap& headers) const; + std::map>, std::less<>> + append_headers_; }; } // namespace KusciaHeaderDecorator } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/source/filters/http/kuscia_poller/callbacks.cc b/kuscia/source/filters/http/kuscia_poller/callbacks.cc index 0be7208..5ddde9b 100644 --- a/kuscia/source/filters/http/kuscia_poller/callbacks.cc +++ b/kuscia/source/filters/http/kuscia_poller/callbacks.cc @@ -13,232 +13,247 @@ // limitations under the License. #include "kuscia/source/filters/http/kuscia_poller/callbacks.h" -#include "source/common/http/utility.h" -#include "source/common/http/message_impl.h" -#include "source/common/http/headers.h" #include "callbacks.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "source/common/http/headers.h" +#include "source/common/http/message_impl.h" +#include "source/common/http/utility.h" namespace Envoy { namespace Extensions { namespace HttpFilters { namespace KusciaPoller { -bool replyToReceiver(const std::string& conn_id, Upstream::ClusterManager& cluster_manager, const std::string& msg_id, const std::string& host, const ResponseMessagePb& resp_msg_pb, int timeout, std::string &errmsg) -{ - // Ensure the existence of the target cluster - std::string cluster_name = "internal-cluster"; - Upstream::ThreadLocalCluster* cluster = cluster_manager.getThreadLocalCluster(cluster_name); - if (cluster == nullptr) { - errmsg = "cluster " + cluster_name + " not found"; - return false; - } - - // Get asynchronous HTTP client - Http::AsyncClient& client = cluster->httpAsyncClient(); - - // Construct request message - Http::RequestMessagePtr req_msg(new Http::RequestMessageImpl()); - req_msg->headers().setPath("/reply?msgid=" + msg_id); - req_msg->headers().setHost(host); - req_msg->headers().setReferenceMethod(Http::Headers::get().MethodValues.Post); - req_msg->headers().setReferenceContentType(Envoy::Http::Headers::get().ContentTypeValues.Protobuf); - - std::string serialized_data = resp_msg_pb.SerializeAsString(); - - req_msg->body().add(serialized_data.data(), serialized_data.size()); - - // Send asynchronous requests - ReceiverCallbacks* callbacks = new ReceiverCallbacks(conn_id, msg_id); - Envoy::Http::AsyncClient::RequestOptions options; - options.setTimeout(std::chrono::milliseconds(timeout * 1000)); - Http::AsyncClient::Request* request = client.send(std::move(req_msg), *callbacks, options); - if (request == nullptr) { - delete callbacks; - callbacks = nullptr; - errmsg = "can't create request"; - return false; - } - - return true; -} - -ApplicationCallbacks::~ApplicationCallbacks() -{ - ENVOY_LOG(debug, "[{}] ApplicationCallbacks destroyed, message id: {}", conn_id_, message_id_); -} - -void KusciaPoller::ApplicationCallbacks::onSuccess(const Http::AsyncClient::Request &, Http::ResponseMessagePtr &&response) -{ - replyToReceiverOnSuccess(std::move(response)); - delete this; -} - -void ApplicationCallbacks::replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response) -{ - Http::ResponseHeaderMap& headers = response->headers(); - const uint64_t status_code = Http::Utility::getResponseStatus(headers); - if (status_code == 200) { - ENVOY_LOG(info, "[{}] Forward request message {} successully, status code: {}", conn_id_, message_id_, status_code); - } else { - ENVOY_LOG(warn, "[{}] Forward request message {} , status code: {}", conn_id_, message_id_, status_code); - } - - ResponseMessagePb resp_msg_pb; - headers.iterate([&resp_msg_pb](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { - (*resp_msg_pb.mutable_headers())[std::string(header.key().getStringView())] = std::string(header.value().getStringView()); - return Envoy::Http::HeaderMap::Iterate::Continue; - }); - resp_msg_pb.set_status_code(status_code); - resp_msg_pb.set_body(response->body().toString()); - resp_msg_pb.set_end_stream(true); +bool replyToReceiver(const std::string& conn_id, const std::string& req_host, + Upstream::ClusterManager& cluster_manager, const std::string& msg_id, + const std::string& host, const ResponseMessagePb& resp_msg_pb, int timeout, + std::string& errmsg) { + // Ensure the existence of the target cluster + std::string cluster_name = "internal-cluster"; + Upstream::ThreadLocalCluster* cluster = cluster_manager.getThreadLocalCluster(cluster_name); + if (cluster == nullptr) { + errmsg = "cluster " + cluster_name + " not found"; + return false; + } + + // Get asynchronous HTTP client + Http::AsyncClient& client = cluster->httpAsyncClient(); + + // Construct request message + Http::RequestMessagePtr req_msg(new Http::RequestMessageImpl()); + req_msg->headers().setPath("/reply?msgid=" + msg_id); + req_msg->headers().setHost(host); + req_msg->headers().setReferenceMethod(Http::Headers::get().MethodValues.Post); + req_msg->headers().setReferenceContentType( + Envoy::Http::Headers::get().ContentTypeValues.Protobuf); + req_msg->headers().setReferenceKey(KusciaCommon::HeaderTransitHash, req_host); + + std::string serialized_data = resp_msg_pb.SerializeAsString(); + + req_msg->body().add(serialized_data.data(), serialized_data.size()); + + // Send asynchronous requests + ReceiverCallbacks* callbacks = new ReceiverCallbacks(conn_id, msg_id); + Envoy::Http::AsyncClient::RequestOptions options; + options.setTimeout(std::chrono::milliseconds(timeout * 1000)); + Http::AsyncClient::Request* request = client.send(std::move(req_msg), *callbacks, options); + if (request == nullptr) { + delete callbacks; + callbacks = nullptr; + errmsg = "can't create request"; + return false; + } + + return true; +} - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } +ApplicationCallbacks::~ApplicationCallbacks() { + ENVOY_LOG(debug, "[{}] ApplicationCallbacks destroyed, message id: {}", conn_id_, message_id_); } -void ApplicationCallbacks::onFailure(const Http::AsyncClient::Request &, Http::AsyncClient::FailureReason) -{ - ENVOY_LOG(error, "[{}] Forward request message {} error: network error", conn_id_, message_id_); - replyToReceiverOnFailure(); - delete this; +void KusciaPoller::ApplicationCallbacks::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& response) { + replyToReceiverOnSuccess(std::move(response)); + delete this; } -void ApplicationCallbacks::replyToReceiverOnFailure() -{ - ResponseMessagePb resp_msg_pb; - resp_msg_pb.set_status_code(502); - resp_msg_pb.set_end_stream(true); - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } -} - -KusciaPoller::ApiserverCallbacks::~ApiserverCallbacks() -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks destroyed, message id: {}", conn_id_, message_id_); -} - -void ApiserverCallbacks::onHeaders(Http::ResponseHeaderMapPtr &&headers, bool end_stream) -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks onHeaders, message id: {}", conn_id_, message_id_); - if (headers == nullptr) { - ENVOY_LOG(error, "[{}] Headers is null, message id: {}", conn_id_, message_id_); - return; - } - - const uint64_t status_code = Http::Utility::getResponseStatus(*headers); - if (status_code == 200) { - ENVOY_LOG(info, "[{}] Forward request message {} successully, status code: {}", conn_id_, message_id_, status_code); - } else { - ENVOY_LOG(warn, "[{}] Forward request message {} , status code: {}", conn_id_, message_id_, status_code); - } - - ResponseMessagePb resp_msg_pb; - headers->iterate([&resp_msg_pb](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { - (*resp_msg_pb.mutable_headers())[std::string(header.key().getStringView())] = std::string(header.value().getStringView()); - return Envoy::Http::HeaderMap::Iterate::Continue; - }); - resp_msg_pb.set_status_code(status_code); - resp_msg_pb.set_chunk_data(true); - resp_msg_pb.set_index(index_++); - if (end_stream) { - resp_msg_pb.set_end_stream(true); - } - - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } -} - -void ApiserverCallbacks::onData(Buffer::Instance &data, bool end_stream) -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks onData, message id: {}, data len: {}", conn_id_, message_id_, data.length()); - ResponseMessagePb resp_msg_pb; - - resp_msg_pb.set_body(data.toString()); - resp_msg_pb.set_chunk_data(true); - resp_msg_pb.set_index(index_++); - if (end_stream) { - resp_msg_pb.set_end_stream(true); - } - - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } -} - -void ApiserverCallbacks::onTrailers(Http::ResponseTrailerMapPtr &&) -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks onTrailers, message id: {}", conn_id_, message_id_); - ResponseMessagePb resp_msg_pb; - resp_msg_pb.set_end_stream(true); - resp_msg_pb.set_chunk_data(true); - resp_msg_pb.set_index(index_++); +void ApplicationCallbacks::replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response) { + Http::ResponseHeaderMap& headers = response->headers(); + const uint64_t status_code = Http::Utility::getResponseStatus(headers); + if (status_code == 200) { + ENVOY_LOG(info, "[{}] Forward request message {} successully, status code: {}", conn_id_, + message_id_, status_code); + } else { + ENVOY_LOG(warn, "[{}] Forward request message {} , status code: {}", conn_id_, message_id_, + status_code); + } + + ResponseMessagePb resp_msg_pb; + headers.iterate([&resp_msg_pb](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { + (*resp_msg_pb.mutable_headers())[std::string(header.key().getStringView())] = + std::string(header.value().getStringView()); + return Envoy::Http::HeaderMap::Iterate::Continue; + }); + resp_msg_pb.set_status_code(status_code); + resp_msg_pb.set_body(response->body().toString()); + resp_msg_pb.set_end_stream(true); + + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } +} - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } +void ApplicationCallbacks::onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason) { + ENVOY_LOG(error, "[{}] Forward request message {} error: network error", conn_id_, message_id_); + replyToReceiverOnFailure(); + delete this; } -void ApiserverCallbacks::onReset() -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks onReset, message id: {}", conn_id_, message_id_); - ResponseMessagePb resp_msg_pb; - resp_msg_pb.set_status_code(503); - resp_msg_pb.set_end_stream(true); - resp_msg_pb.set_index(index_++); +void ApplicationCallbacks::replyToReceiverOnFailure() { + ResponseMessagePb resp_msg_pb; + resp_msg_pb.set_status_code(502); + resp_msg_pb.set_end_stream(true); + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } +} - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message_id_, receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, errmsg); - } +KusciaPoller::ApiserverCallbacks::~ApiserverCallbacks() { + ENVOY_LOG(info, "[{}] ApiserverCallbacks destroyed, message id: {}", conn_id_, message_id_); +} - delete this; +void ApiserverCallbacks::onHeaders(Http::ResponseHeaderMapPtr&& headers, bool end_stream) { + ENVOY_LOG(info, "[{}] ApiserverCallbacks onHeaders, message id: {}", conn_id_, message_id_); + if (headers == nullptr) { + ENVOY_LOG(error, "[{}] Headers is null, message id: {}", conn_id_, message_id_); + return; + } + + const uint64_t status_code = Http::Utility::getResponseStatus(*headers); + if (status_code == 200) { + ENVOY_LOG(info, "[{}] Forward request message {} successully, status code: {}", conn_id_, + message_id_, status_code); + } else { + ENVOY_LOG(warn, "[{}] Forward request message {} , status code: {}", conn_id_, message_id_, + status_code); + } + + ResponseMessagePb resp_msg_pb; + headers->iterate([&resp_msg_pb](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { + (*resp_msg_pb.mutable_headers())[std::string(header.key().getStringView())] = + std::string(header.value().getStringView()); + return Envoy::Http::HeaderMap::Iterate::Continue; + }); + resp_msg_pb.set_status_code(status_code); + resp_msg_pb.set_chunk_data(true); + resp_msg_pb.set_index(index_++); + if (end_stream) { + resp_msg_pb.set_end_stream(true); + } + + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } } +void ApiserverCallbacks::onData(Buffer::Instance& data, bool end_stream) { + ENVOY_LOG(info, "[{}] ApiserverCallbacks onData, message id: {}, data len: {}", conn_id_, + message_id_, data.length()); + ResponseMessagePb resp_msg_pb; -void ApiserverCallbacks::onComplete() -{ - ENVOY_LOG(info, "[{}] ApiserverCallbacks onComplete, message id: {}", conn_id_, message_id_); + resp_msg_pb.set_body(data.toString()); + resp_msg_pb.set_chunk_data(true); + resp_msg_pb.set_index(index_++); + if (end_stream) { + resp_msg_pb.set_end_stream(true); + } + + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } +} - delete this; +void ApiserverCallbacks::onTrailers(Http::ResponseTrailerMapPtr&&) { + ENVOY_LOG(info, "[{}] ApiserverCallbacks onTrailers, message id: {}", conn_id_, message_id_); + ResponseMessagePb resp_msg_pb; + resp_msg_pb.set_end_stream(true); + resp_msg_pb.set_chunk_data(true); + resp_msg_pb.set_index(index_++); + + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } } -void ApiserverCallbacks::saveRequestMessage(Http::RequestMessagePtr &&req_message) -{ - req_message_ = std::move(req_message); +void ApiserverCallbacks::onReset() { + ENVOY_LOG(info, "[{}] ApiserverCallbacks onReset, message id: {}", conn_id_, message_id_); + ResponseMessagePb resp_msg_pb; + resp_msg_pb.set_status_code(503); + resp_msg_pb.set_end_stream(true); + resp_msg_pb.set_index(index_++); + + std::string errmsg; + if (!replyToReceiver(conn_id_, req_host_, cluster_manager_, message_id_, receiver_host_, + resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message_id_, + errmsg); + } + + delete this; } -ReceiverCallbacks::~ReceiverCallbacks() -{ - ENVOY_LOG(debug, "[{}] ReceiverCallbacks destroyed, message id: {}", conn_id_, msg_id_); +void ApiserverCallbacks::onComplete() { + ENVOY_LOG(info, "[{}] ApiserverCallbacks onComplete, message id: {}", conn_id_, message_id_); + + delete this; } -void ReceiverCallbacks::onSuccess(const Http::AsyncClient::Request &, Http::ResponseMessagePtr &&response) -{ - const uint64_t status_code = Http::Utility::getResponseStatus(response->headers()); - if (status_code == 200) { - ENVOY_LOG(info, "[{}] Forward response message {} successully, status code: {}", conn_id_, msg_id_, status_code); - } else { - ENVOY_LOG(warn, "[{}] Forward response message {} , status code: {}", conn_id_, msg_id_, status_code); - } +void ApiserverCallbacks::saveRequestMessage(Http::RequestMessagePtr&& req_message) { + req_message_ = std::move(req_message); +} - delete this; +ReceiverCallbacks::~ReceiverCallbacks() { + ENVOY_LOG(debug, "[{}] ReceiverCallbacks destroyed, message id: {}", conn_id_, msg_id_); } -void ReceiverCallbacks::onFailure(const Http::AsyncClient::Request &, Http::AsyncClient::FailureReason) -{ - ENVOY_LOG(error, "[{}] Forward response message {} error: {}", conn_id_, msg_id_, "network error"); - delete this; +void ReceiverCallbacks::onSuccess(const Http::AsyncClient::Request&, + Http::ResponseMessagePtr&& response) { + const uint64_t status_code = Http::Utility::getResponseStatus(response->headers()); + if (status_code == 200) { + ENVOY_LOG(info, "[{}] Forward response message {} successully, status code: {}", conn_id_, + msg_id_, status_code); + } else { + ENVOY_LOG(warn, "[{}] Forward response message {} , status code: {}", conn_id_, msg_id_, + status_code); + } + + delete this; } +void ReceiverCallbacks::onFailure(const Http::AsyncClient::Request&, + Http::AsyncClient::FailureReason) { + ENVOY_LOG(error, "[{}] Forward response message {} error: {}", conn_id_, msg_id_, + "network error"); + delete this; } + +} // namespace KusciaPoller // namespace KusciaPoller } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_poller/callbacks.h b/kuscia/source/filters/http/kuscia_poller/callbacks.h index b0743b9..fa820ea 100644 --- a/kuscia/source/filters/http/kuscia_poller/callbacks.h +++ b/kuscia/source/filters/http/kuscia_poller/callbacks.h @@ -14,77 +14,97 @@ #pragma once -#include "kuscia/source/filters/http/kuscia_poller/common.h" #include "envoy/http/async_client.h" -#include "source/common/common/logger.h" #include "envoy/upstream/cluster_manager.h" +#include "kuscia/source/filters/http/kuscia_poller/common.h" +#include "source/common/common/logger.h" namespace Envoy { namespace Extensions { namespace HttpFilters { namespace KusciaPoller { -extern bool replyToReceiver(const std::string& conn_id, Upstream::ClusterManager& cluster_manager, const std::string& msg_id, const std::string& host, const ResponseMessagePb& resp_msg_pb, int timeout, std::string &errmsg); +extern bool replyToReceiver(const std::string& conn_id, const std::string& req_host, + Upstream::ClusterManager& cluster_manager, const std::string& msg_id, + const std::string& host, const ResponseMessagePb& resp_msg_pb, + int timeout, std::string& errmsg); -class ApplicationCallbacks : public Http::AsyncClient::Callbacks, public Logger::Loggable { +class ApplicationCallbacks : public Http::AsyncClient::Callbacks, + public Logger::Loggable { public: - ApplicationCallbacks(const std::string& conn_id, Upstream::ClusterManager& cluster_manager, const std::string& message_id, const std::string& peer_receiver_host, int rsp_timeout) - : conn_id_(conn_id), cluster_manager_(cluster_manager), message_id_(message_id), receiver_host_(peer_receiver_host), rsp_timeout_(rsp_timeout) {} - ~ApplicationCallbacks(); + ApplicationCallbacks(const std::string& conn_id, const std::string& req_host, + Upstream::ClusterManager& cluster_manager, const std::string& message_id, + const std::string& peer_receiver_host, int rsp_timeout) + : conn_id_(conn_id), req_host_(req_host), cluster_manager_(cluster_manager), + message_id_(message_id), receiver_host_(peer_receiver_host), rsp_timeout_(rsp_timeout) {} + ~ApplicationCallbacks(); - void onSuccess(const Http::AsyncClient::Request& request, Http::ResponseMessagePtr&& response) override; - void onFailure(const Http::AsyncClient::Request& request, Http::AsyncClient::FailureReason reason) override; - void onBeforeFinalizeUpstreamSpan(Tracing::Span&, const Http::ResponseHeaderMap*) override {} - void replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response); - void replyToReceiverOnFailure(); + void onSuccess(const Http::AsyncClient::Request& request, + Http::ResponseMessagePtr&& response) override; + void onFailure(const Http::AsyncClient::Request& request, + Http::AsyncClient::FailureReason reason) override; + void onBeforeFinalizeUpstreamSpan(Tracing::Span&, const Http::ResponseHeaderMap*) override {} + void replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response); + void replyToReceiverOnFailure(); private: - std::string conn_id_; - Upstream::ClusterManager& cluster_manager_; - std::string message_id_; - std::string receiver_host_; - int rsp_timeout_; + std::string conn_id_; + std::string req_host_; + Upstream::ClusterManager& cluster_manager_; + std::string message_id_; + std::string receiver_host_; + int rsp_timeout_; }; -class ApiserverCallbacks : public Http::AsyncClient::StreamCallbacks, public Logger::Loggable { +class ApiserverCallbacks : public Http::AsyncClient::StreamCallbacks, + public Logger::Loggable { public: - ApiserverCallbacks(const std::string& conn_id, Upstream::ClusterManager& cluster_manager, const std::string& message_id, const std::string& peer_receiver_host, int rsp_timeout) - : conn_id_(conn_id), cluster_manager_(cluster_manager), message_id_(message_id), receiver_host_(peer_receiver_host), rsp_timeout_(rsp_timeout), index_(0) {} - ~ApiserverCallbacks(); + ApiserverCallbacks(const std::string& conn_id, const std::string& req_host, + Upstream::ClusterManager& cluster_manager, const std::string& message_id, + const std::string& peer_receiver_host, int rsp_timeout) + : conn_id_(conn_id), req_host_(req_host), cluster_manager_(cluster_manager), + message_id_(message_id), receiver_host_(peer_receiver_host), rsp_timeout_(rsp_timeout), + index_(0) {} + ~ApiserverCallbacks(); - void onHeaders(Http::ResponseHeaderMapPtr&& headers, bool end_stream) override; - void onData(Buffer::Instance& data, bool end_stream) override; - void onTrailers(Http::ResponseTrailerMapPtr&& trailers) override; - void onReset() override; - void onComplete() override; + void onHeaders(Http::ResponseHeaderMapPtr&& headers, bool end_stream) override; + void onData(Buffer::Instance& data, bool end_stream) override; + void onTrailers(Http::ResponseTrailerMapPtr&& trailers) override; + void onReset() override; + void onComplete() override; - void replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response); - void replyToReceiverOnFailure(); + void replyToReceiverOnSuccess(Http::ResponseMessagePtr&& response); + void replyToReceiverOnFailure(); - void saveRequestMessage(Http::RequestMessagePtr&& req_message); + void saveRequestMessage(Http::RequestMessagePtr&& req_message); private: - std::string conn_id_; - // Http::RequestHeaderMapPtr headers_; - Http::RequestMessagePtr req_message_; - Upstream::ClusterManager& cluster_manager_; - std::string message_id_; - std::string receiver_host_; - int rsp_timeout_; - int index_; + std::string conn_id_; + std::string req_host_; + // Http::RequestHeaderMapPtr headers_; + Http::RequestMessagePtr req_message_; + Upstream::ClusterManager& cluster_manager_; + std::string message_id_; + std::string receiver_host_; + int rsp_timeout_; + int index_; }; -class ReceiverCallbacks : public Http::AsyncClient::Callbacks, public Logger::Loggable { +class ReceiverCallbacks : public Http::AsyncClient::Callbacks, + public Logger::Loggable { public: - ReceiverCallbacks(const std::string& conn_id, const std::string& msg_id) : conn_id_(conn_id), msg_id_(msg_id) {} - ~ReceiverCallbacks(); - void onSuccess(const Http::AsyncClient::Request& request, Http::ResponseMessagePtr&& response) override; - void onFailure(const Http::AsyncClient::Request& request, Http::AsyncClient::FailureReason reason) override; - void onBeforeFinalizeUpstreamSpan(Tracing::Span&, const Http::ResponseHeaderMap*) override {} + ReceiverCallbacks(const std::string& conn_id, const std::string& msg_id) + : conn_id_(conn_id), msg_id_(msg_id) {} + ~ReceiverCallbacks(); + void onSuccess(const Http::AsyncClient::Request& request, + Http::ResponseMessagePtr&& response) override; + void onFailure(const Http::AsyncClient::Request& request, + Http::AsyncClient::FailureReason reason) override; + void onBeforeFinalizeUpstreamSpan(Tracing::Span&, const Http::ResponseHeaderMap*) override {} private: - std::string conn_id_; - std::string msg_id_; + std::string conn_id_; + std::string msg_id_; }; } // namespace KusciaPoller diff --git a/kuscia/source/filters/http/kuscia_poller/config.cc b/kuscia/source/filters/http/kuscia_poller/config.cc index 8b98fad..36440a0 100755 --- a/kuscia/source/filters/http/kuscia_poller/config.cc +++ b/kuscia/source/filters/http/kuscia_poller/config.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_poller/config.h" #include "envoy/registry/registry.h" @@ -26,12 +25,12 @@ namespace KusciaPoller { Http::FilterFactoryCb PollerConfigFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::http::kuscia_poller::v3::Poller& proto_config, - const std::string&, - Server::Configuration::FactoryContext& context) { + const std::string&, Server::Configuration::FactoryContext& context) { - return [proto_config, &context](Http::FilterChainFactoryCallbacks& callbacks) -> void { - callbacks.addStreamFilter(std::make_shared(proto_config, context.clusterManager(), context.timeSource())); - }; + return [proto_config, &context](Http::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamFilter(std::make_shared( + proto_config, context.serverFactoryContext().clusterManager(), context.serverFactoryContext().timeSource())); + }; } REGISTER_FACTORY(PollerConfigFactory, Server::Configuration::NamedHttpFilterConfigFactory); diff --git a/kuscia/source/filters/http/kuscia_poller/config.h b/kuscia/source/filters/http/kuscia_poller/config.h index 76c612f..9a0936d 100755 --- a/kuscia/source/filters/http/kuscia_poller/config.h +++ b/kuscia/source/filters/http/kuscia_poller/config.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -27,14 +26,14 @@ namespace Extensions { namespace HttpFilters { namespace KusciaPoller { -class PollerConfigFactory : public Extensions::HttpFilters::Common::FactoryBase { - public: - PollerConfigFactory() : FactoryBase("envoy.filters.http.kuscia_poller") {} +class PollerConfigFactory : public Extensions::HttpFilters::Common::FactoryBase< + envoy::extensions::filters::http::kuscia_poller::v3::Poller> { +public: + PollerConfigFactory() : FactoryBase("envoy.filters.http.kuscia_poller") {} - Http::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_poller::v3::Poller&, - const std::string&, - Server::Configuration::FactoryContext&) override; + Http::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::http::kuscia_poller::v3::Poller&, const std::string&, + Server::Configuration::FactoryContext&) override; }; } // namespace KusciaPoller diff --git a/kuscia/source/filters/http/kuscia_poller/poller_filter.cc b/kuscia/source/filters/http/kuscia_poller/poller_filter.cc index 54d4a6a..56ce880 100644 --- a/kuscia/source/filters/http/kuscia_poller/poller_filter.cc +++ b/kuscia/source/filters/http/kuscia_poller/poller_filter.cc @@ -13,14 +13,14 @@ // limitations under the License. #include "kuscia/source/filters/http/kuscia_poller/poller_filter.h" -#include "kuscia/source/filters/http/kuscia_poller/callbacks.h" -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "envoy/http/filter.h" #include "envoy/http/header_map.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "kuscia/source/filters/http/kuscia_poller/callbacks.h" +#include "poller_filter.h" #include "source/common/http/headers.h" #include "source/common/http/message_impl.h" #include "source/common/http/utility.h" -#include "poller_filter.h" namespace Envoy { namespace Extensions { @@ -28,304 +28,304 @@ namespace HttpFilters { namespace KusciaPoller { bool starts_with(absl::string_view str, absl::string_view prefix) { - return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix; + return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix; } -// isPollRequest checks if a host string matches the format of the three part domain name "receiver.*.svc and a path string matches the format "/poll*" -bool isPollRequest(absl::string_view host, absl::string_view path, const std::string expected_service_name, std::string &domain_id) { - if (host.size() < 13) { // "receiver..svc" is 13 characters long - return false; - } - - // Split the host into segments based on '.' - std::vector segments; - size_t start = 0; - size_t end = host.find('.'); - while (end != absl::string_view::npos) { - segments.push_back(host.substr(start, end - start)); - start = end + 1; - end = host.find('.', start); - } - segments.push_back(host.substr(start, end)); - - // Check if the host has exactly 3 segments, starts with "receiver", and ends with "svc" - if (!(segments.size() == 3 && segments[0] == expected_service_name && segments[2] == "svc")) { - return false; - } - // Check if the path starts with "/poll" - if (!starts_with(path, "/poll")) { - return false; - } - - domain_id = std::string(segments[1]); - - return true; -} +// isPollRequest checks if a host string matches the format of the three part domain name +// "receiver.*.svc and a path string matches the format "/poll*" +bool isPollRequest(absl::string_view host, absl::string_view path, + const std::string expected_service_name, std::string& domain_id) { + if (host.size() < 13) { // "receiver..svc" is 13 characters long + return false; + } + + // Split the host into segments based on '.' + std::vector segments; + size_t start = 0; + size_t end = host.find('.'); + while (end != absl::string_view::npos) { + segments.push_back(host.substr(start, end - start)); + start = end + 1; + end = host.find('.', start); + } + segments.push_back(host.substr(start, end)); + + // Check if the host has exactly 3 segments, starts with "receiver", and ends with "svc" + if (!(segments.size() == 3 && segments[0] == expected_service_name && segments[2] == "svc")) { + return false; + } + // Check if the path starts with "/poll" + if (!starts_with(path, "/poll")) { + return false; + } -PollerFilter::PollerFilter(const PollerConfigPbConfig &config, Upstream::ClusterManager & cluster_manager, TimeSource& time_source) - : receiver_service_name_(config.receiver_service_name()), req_timeout_(config.request_timeout()), rsp_timeout_(config.response_timeout()), - heartbeat_interval_(config.heartbeat_interval()), cluster_manager_(cluster_manager), time_source_(time_source) -{ - if (receiver_service_name_.size() == 0) { - receiver_service_name_ = "receiver"; - } - if (req_timeout_ <= 0) { - req_timeout_ = 30; - } - if (rsp_timeout_ <= 0) { - rsp_timeout_ = 30; - } - if (heartbeat_interval_ <= 0) { - heartbeat_interval_ = 25; - } + domain_id = std::string(segments[1]); - for (const auto& source_headers : config.append_headers()) { - std::vector> headers; - headers.reserve(source_headers.headers_size()); - for (const auto& entry : source_headers.headers()) { - headers.emplace_back(entry.key(), entry.value()); - } - append_headers_.emplace(source_headers.source(), headers); - } + return true; } -PollerFilter::~PollerFilter() -{ - if (response_timer_) { - response_timer_->disableTimer(); +PollerFilter::PollerFilter(const PollerConfigPbConfig& config, + Upstream::ClusterManager& cluster_manager, TimeSource& time_source) + : receiver_service_name_(config.receiver_service_name()), + req_timeout_(config.request_timeout()), rsp_timeout_(config.response_timeout()), + heartbeat_interval_(config.heartbeat_interval()), cluster_manager_(cluster_manager), + time_source_(time_source) { + if (receiver_service_name_.size() == 0) { + receiver_service_name_ = "receiver"; + } + if (req_timeout_ <= 0) { + req_timeout_ = 30; + } + if (rsp_timeout_ <= 0) { + rsp_timeout_ = 30; + } + if (heartbeat_interval_ <= 0) { + heartbeat_interval_ = 25; + } + + for (const auto& source_headers : config.append_headers()) { + std::vector> headers; + headers.reserve(source_headers.headers_size()); + for (const auto& entry : source_headers.headers()) { + headers.emplace_back(entry.key(), entry.value()); } + append_headers_.emplace(source_headers.source(), headers); + } } -Http::FilterHeadersStatus PollerFilter::decodeHeaders(Http::RequestHeaderMap &headers, bool) -{ - auto host = headers.getHostValue(); - auto path = headers.getPathValue(); - - if (isPollRequest(host, path, receiver_service_name_, peer_domain_)) - { - auto query_params = Http::Utility::parseQueryString(path); - auto svc = query_params.find(KusciaCommon::ServiceParamKey); - - Envoy::SystemTime system_time = time_source_.systemTime(); - std::chrono::seconds seconds_since_epoch = std::chrono::duration_cast( - system_time.time_since_epoch()); - std::string current_timestamp = std::to_string(seconds_since_epoch.count()); - if (svc == query_params.end() && svc->second.empty()) { - conn_id_ = "unknown:" + peer_domain_ + ":" + current_timestamp; - } else { - conn_id_ = std::string(svc->second) + ":" + peer_domain_ + ":" + current_timestamp; - } - peer_receiver_host_ = receiver_service_name_ + "." + peer_domain_ + ".svc"; - forward_response_ = true; - ENVOY_LOG(info, "[{}] Poller begin to forward response, host: {}, path: {}, peer_domain_: {}", conn_id_, host, path, peer_domain_); - } - - return Http::FilterHeadersStatus::Continue; +PollerFilter::~PollerFilter() { + if (response_timer_) { + response_timer_->disableTimer(); + } } -Http::FilterHeadersStatus PollerFilter::encodeHeaders(Http::ResponseHeaderMap &headers, bool) -{ - if (forward_response_) - { - const auto status = headers.getStatusValue(); - ENVOY_LOG(info, "[{}] Poller status: {}", conn_id_, status); - - if (status == "200") - { - encoder_callbacks_->setEncoderBufferLimit(1024 * 1024 * 100); - response_timer_ = encoder_callbacks_->dispatcher().createTimer([this]() -> void { - sendHeartbeat(); - }); - response_timer_->enableTimer(std::chrono::seconds(heartbeat_interval_)); - } else { - forward_response_ = false; - } - } +Http::FilterHeadersStatus PollerFilter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { + auto host = headers.getHostValue(); + auto path = headers.getPathValue(); - return Http::FilterHeadersStatus::Continue; -} + if (isPollRequest(host, path, receiver_service_name_, peer_domain_)) { + auto query_params = Http::Utility::QueryParamsMulti::parseQueryString(path); + auto svc = query_params.getFirstValue(KusciaCommon::ServiceParamKey); -Http::FilterDataStatus PollerFilter::encodeData(Buffer::Instance &data, bool) -{ - if (!forward_response_) - { - return Http::FilterDataStatus::Continue; + Envoy::SystemTime system_time = time_source_.systemTime(); + std::chrono::seconds seconds_since_epoch = + std::chrono::duration_cast(system_time.time_since_epoch()); + std::string current_timestamp = std::to_string(seconds_since_epoch.count()); + if (!svc.has_value()) { + conn_id_ = "unknown:" + peer_domain_ + ":" + current_timestamp; + } else { + conn_id_ = std::string(svc.value()) + ":" + peer_domain_ + ":" + current_timestamp; } + peer_receiver_host_ = receiver_service_name_ + "." + peer_domain_ + ".svc"; + forward_response_ = true; + ENVOY_LOG(info, "[{}] Poller begin to forward response, host: {}, path: {}, peer_domain_: {}", + conn_id_, host, path, peer_domain_); + } - while (attemptToDecodeMessage(data)){} - - return Http::FilterDataStatus::StopIterationNoBuffer; + return Http::FilterHeadersStatus::Continue; } -bool PollerFilter::attemptToDecodeMessage(Buffer::Instance &data) -{ - if (data.length() == 0) { - return false; - } +Http::FilterHeadersStatus PollerFilter::encodeHeaders(Http::ResponseHeaderMap& headers, bool) { + if (forward_response_) { + const auto status = headers.getStatusValue(); + ENVOY_LOG(info, "[{}] Poller status: {}", conn_id_, status); - RequestMessagePb message; - KusciaCommon::DecodeStatus status = decoder_.decode(data, message); - if (status == KusciaCommon::DecodeStatus::Ok) { - std::string errmsg; - int32_t status_code = forwardMessage(message, errmsg); - if (status_code != 200) { - ENVOY_LOG(error, "[{}] Forward message {} to {}{} error: {}", conn_id_, message.id(), message.host(), message.path(), errmsg); - - ResponseMessagePb resp_msg_pb; - resp_msg_pb.set_status_code(status_code); - resp_msg_pb.set_end_stream(true); - std::string errmsg; - if (!replyToReceiver(conn_id_, cluster_manager_, message.id(), peer_receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { - ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, message.id(), errmsg); - } - } - return true; - } else if (status == KusciaCommon::DecodeStatus::NeedMoreData) { - ENVOY_LOG(info, "[{}] Decode message need more data", conn_id_); - return false; + if (status == "200") { + encoder_callbacks_->setEncoderBufferLimit(1024 * 1024 * 100); + response_timer_ = + encoder_callbacks_->dispatcher().createTimer([this]() -> void { sendHeartbeat(); }); + response_timer_->enableTimer(std::chrono::seconds(heartbeat_interval_)); } else { - ENVOY_LOG(error, "[{}] Decode message error code: {}", conn_id_, KusciaCommon::decodeStatusString(status)); - encoder_callbacks_->resetStream(); - return false; + forward_response_ = false; } + } - return false; + return Http::FilterHeadersStatus::Continue; } -bool parseKusciaHost(const std::string& host, std::string &service) { - // Split the host into segments based on '.' - std::vector segments; - size_t start = 0; - size_t end = host.find('.'); - while (end != std::string::npos) { - segments.push_back(host.substr(start, end - start)); - start = end + 1; - end = host.find('.', start); - } - segments.push_back(host.substr(start, end)); - - if (segments.size() < 1) { - return false; - } +Http::FilterDataStatus PollerFilter::encodeData(Buffer::Instance& data, bool) { + if (!forward_response_) { + return Http::FilterDataStatus::Continue; + } - service = segments[0]; + while (attemptToDecodeMessage(data)) { + } - return true; + return Http::FilterDataStatus::StopIterationNoBuffer; } -int32_t PollerFilter::forwardMessage(const RequestMessagePb &message, std::string& errmsg) -{ - std::string host = message.host(); - - ENVOY_LOG(info, "[{}] Forward message {} to {}{}, method: {}", conn_id_, message.id(), message.host(), message.path(), message.method()); - - std::string service_name; - if (!parseKusciaHost(host, service_name)) { - errmsg = "parse kuscia host " + host + " error"; - return 500; - } - std::string cluster_name = "service-" + service_name; - // TODO check service_name - if (message.path() == "/handshake") { - cluster_name = "handshake-cluster"; +bool PollerFilter::attemptToDecodeMessage(Buffer::Instance& data) { + if (data.length() == 0) { + return false; + } + + RequestMessagePb message; + KusciaCommon::DecodeStatus status = decoder_.decode(data, message); + if (status == KusciaCommon::DecodeStatus::Ok) { + std::string errmsg; + int32_t status_code = forwardMessage(message, errmsg); + if (status_code != 200) { + ENVOY_LOG(error, "[{}] Forward message {} to {}{} error: {}", conn_id_, message.id(), + message.host(), message.path(), errmsg); + + ResponseMessagePb resp_msg_pb; + resp_msg_pb.set_status_code(status_code); + resp_msg_pb.set_end_stream(true); + std::string errmsg; + if (!replyToReceiver(conn_id_, message.host(), cluster_manager_, message.id(), + peer_receiver_host_, resp_msg_pb, rsp_timeout_, errmsg)) { + ENVOY_LOG(error, "[{}] Reply to receiver error: {}, message id: {}", conn_id_, + message.id(), errmsg); + } } + return true; + } else if (status == KusciaCommon::DecodeStatus::NeedMoreData) { + ENVOY_LOG(info, "[{}] Decode message need more data", conn_id_); + return false; + } else { + ENVOY_LOG(error, "[{}] Decode message error code: {}", conn_id_, + KusciaCommon::decodeStatusString(status)); + encoder_callbacks_->resetStream(); + return false; + } - // Ensure the existence of the target cluster - // TODO Not considering domain transit - Upstream::ThreadLocalCluster* cluster = cluster_manager_.getThreadLocalCluster(cluster_name); - if (cluster == nullptr) { - errmsg = "cluster " + cluster_name + " not found"; - return 404; - } + return false; +} - // Get asynchronous HTTP client - Http::AsyncClient& client = cluster->httpAsyncClient(); +bool parseKusciaHost(const std::string& host, std::string& service) { + // Split the host into segments based on '.' + std::vector segments; + size_t start = 0; + size_t end = host.find('.'); + while (end != std::string::npos) { + segments.push_back(host.substr(start, end - start)); + start = end + 1; + end = host.find('.', start); + } + segments.push_back(host.substr(start, end)); + + if (segments.size() < 1) { + return false; + } - // Construct request message - Http::RequestMessagePtr req_msg(new Http::RequestMessageImpl()); + service = segments[0]; - for (const auto& header : message.headers()) { - // Add each header to the request's headers - req_msg->headers().addCopy( - Envoy::Http::LowerCaseString(header.first), - header.second - ); - } + return true; +} - req_msg->headers().setPath(message.path()); - req_msg->headers().setHost(message.host()); - req_msg->headers().setMethod(message.method()); - req_msg->body().add(message.body()); +int32_t PollerFilter::forwardMessage(const RequestMessagePb& message, std::string& errmsg) { + std::string host = message.host(); + + ENVOY_LOG(info, "[{}] Forward message {} to {}{}, method: {}", conn_id_, message.id(), + message.host(), message.path(), message.method()); + + std::string service_name; + if (!parseKusciaHost(host, service_name)) { + errmsg = "parse kuscia host " + host + " error"; + return 500; + } + std::string cluster_name = "service-" + service_name; + // TODO check service_name + if (message.path() == "/handshake") { + cluster_name = "handshake-cluster"; + } + + // Ensure the existence of the target cluster + // TODO Not considering domain transit + Upstream::ThreadLocalCluster* cluster = cluster_manager_.getThreadLocalCluster(cluster_name); + if (cluster == nullptr) { + errmsg = "cluster " + cluster_name + " not found"; + return 404; + } + + // Get asynchronous HTTP client + Http::AsyncClient& client = cluster->httpAsyncClient(); + + // Construct request message + Http::RequestMessagePtr req_msg(new Http::RequestMessageImpl()); + + for (const auto& header : message.headers()) { + // Add each header to the request's headers + req_msg->headers().addCopy(Envoy::Http::LowerCaseString(header.first), header.second); + } + + req_msg->headers().setPath(message.path()); + req_msg->headers().setHost(message.host()); + req_msg->headers().setMethod(message.method()); + req_msg->body().add(message.body()); + + if (service_name == "apiserver") { + return forwardToApiserver(client, host, req_msg, message.id(), errmsg); + } else { + return forwardToApplication(client, host, req_msg, message.id(), errmsg); + } +} - if (service_name == "apiserver") { - return forwardToApiserver(client, req_msg, message.id(), errmsg); - } else { - return forwardToApplication(client, req_msg, message.id(), errmsg); - } +int32_t PollerFilter::forwardToApplication(Http::AsyncClient& client, const std::string& req_host, + Http::RequestMessagePtr& req_msg, + const std::string& msg_id, std::string& errmsg) { + // Send asynchronous requests + ApplicationCallbacks* callbacks = new ApplicationCallbacks( + conn_id_, req_host, cluster_manager_, msg_id, peer_receiver_host_, rsp_timeout_); + Envoy::Http::AsyncClient::RequestOptions options; + options.setTimeout(std::chrono::milliseconds(req_timeout_ * 1000)); + Http::AsyncClient::Request* request = client.send(std::move(req_msg), *callbacks, options); + if (request == nullptr) { + delete callbacks; + callbacks = nullptr; + errmsg = "can't create request"; + return 500; + } + + return 200; } -int32_t PollerFilter::forwardToApplication(Http::AsyncClient& client, Http::RequestMessagePtr& req_msg, const std::string& msg_id, std::string& errmsg) -{ - // Send asynchronous requests - ApplicationCallbacks* callbacks = new ApplicationCallbacks(conn_id_, cluster_manager_, msg_id, peer_receiver_host_, rsp_timeout_); - Envoy::Http::AsyncClient::RequestOptions options; - options.setTimeout(std::chrono::milliseconds(req_timeout_ * 1000)); - Http::AsyncClient::Request* request = client.send(std::move(req_msg), *callbacks, options); - if (request == nullptr) { - delete callbacks; - callbacks = nullptr; - errmsg = "can't create request"; - return 500; - } +int32_t PollerFilter::forwardToApiserver(Http::AsyncClient& client, const std::string& req_host, + Http::RequestMessagePtr& req_msg, + const std::string& msg_id, std::string& errmsg) { + // TODO Not considering domain transit + appendHeaders(req_msg->headers()); - return 200; -} + // Send asynchronous requests + ApiserverCallbacks* callbacks = new ApiserverCallbacks( + conn_id_, req_host, cluster_manager_, msg_id, peer_receiver_host_, rsp_timeout_); + Envoy::Http::AsyncClient::StreamOptions options; + options.setTimeout(std::chrono::milliseconds(6 * 60 * 1000)); -int32_t PollerFilter::forwardToApiserver(Http::AsyncClient& client, Http::RequestMessagePtr& req_msg, const std::string& msg_id, std::string& errmsg) -{ - // TODO Not considering domain transit - appendHeaders(req_msg->headers()); - - // Send asynchronous requests - ApiserverCallbacks* callbacks = new ApiserverCallbacks(conn_id_, cluster_manager_, msg_id, peer_receiver_host_, rsp_timeout_); - Envoy::Http::AsyncClient::StreamOptions options; - options.setTimeout(std::chrono::milliseconds(6 * 60 * 1000)); - - // TODO How to know if the client connection is released - Envoy::Http::AsyncClient::Stream* stream = client.start(*callbacks, options); - if (!stream) { - delete callbacks; - callbacks = nullptr; - errmsg = "can't create stream request"; - return 500; - } + // TODO How to know if the client connection is released + Envoy::Http::AsyncClient::Stream* stream = client.start(*callbacks, options); + if (!stream) { + delete callbacks; + callbacks = nullptr; + errmsg = "can't create stream request"; + return 500; + } - stream->sendHeaders(req_msg->headers(), false); + stream->sendHeaders(req_msg->headers(), false); - stream->sendData(req_msg->body(), true); + stream->sendData(req_msg->body(), true); - callbacks->saveRequestMessage(std::move(req_msg)); + callbacks->saveRequestMessage(std::move(req_msg)); - return 200; + return 200; } void PollerFilter::appendHeaders(Http::RequestHeaderMap& headers) { - auto iter = append_headers_.find(peer_domain_); - if (iter != append_headers_.end()) { - for (const auto& entry : iter->second) { - headers.addCopy(Http::LowerCaseString(entry.first), entry.second); - } + auto iter = append_headers_.find(peer_domain_); + if (iter != append_headers_.end()) { + for (const auto& entry : iter->second) { + headers.addCopy(Http::LowerCaseString(entry.first), entry.second); } + } } void PollerFilter::sendHeartbeat() { - Buffer::OwnedImpl hello_data("hello"); - encoder_callbacks_->injectEncodedDataToFilterChain(hello_data, false); + Buffer::OwnedImpl hello_data("hello"); + encoder_callbacks_->injectEncodedDataToFilterChain(hello_data, false); - response_timer_->enableTimer(std::chrono::seconds(heartbeat_interval_)); + response_timer_->enableTimer(std::chrono::seconds(heartbeat_interval_)); } - } // namespace KusciaPoller } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_poller/poller_filter.h b/kuscia/source/filters/http/kuscia_poller/poller_filter.h index 9c54049..0c8563f 100644 --- a/kuscia/source/filters/http/kuscia_poller/poller_filter.h +++ b/kuscia/source/filters/http/kuscia_poller/poller_filter.h @@ -14,15 +14,15 @@ #pragma once +#include "envoy/common/time.h" +#include "envoy/event/timer.h" +#include "envoy/http/filter.h" +#include "envoy/upstream/cluster_manager.h" #include "kuscia/source/filters/http/kuscia_common/coder.h" +#include "kuscia/source/filters/http/kuscia_poller/callbacks.h" #include "kuscia/source/filters/http/kuscia_poller/common.h" #include "source/common/common/logger.h" -#include "envoy/http/filter.h" -#include "envoy/upstream/cluster_manager.h" #include "source/extensions/filters/http/common/pass_through_filter.h" -#include "kuscia/source/filters/http/kuscia_poller/callbacks.h" -#include "envoy/event/timer.h" -#include "envoy/common/time.h" #include #include @@ -33,47 +33,53 @@ namespace KusciaPoller { class PollerFilter : public Http::PassThroughFilter, public Logger::Loggable { public: - explicit PollerFilter(const PollerConfigPbConfig& config, Upstream::ClusterManager& cluster_manager, TimeSource& time_source); - ~PollerFilter(); + explicit PollerFilter(const PollerConfigPbConfig& config, + Upstream::ClusterManager& cluster_manager, TimeSource& time_source); + ~PollerFilter(); - Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override; + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override; - // Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; + // Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override; - Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, bool end_stream) override; + Http::FilterHeadersStatus encodeHeaders(Http::ResponseHeaderMap& headers, + bool end_stream) override; - Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; + Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; private: - bool attemptToDecodeMessage(Buffer::Instance& data); - - int32_t forwardMessage(const RequestMessagePb &message, std::string& errmsg); - int32_t forwardToApplication(Http::AsyncClient& client, Http::RequestMessagePtr& req_msg, const std::string& msg_id, std::string& errmsg); - int32_t forwardToApiserver(Http::AsyncClient& client, Http::RequestMessagePtr& req_msg, const std::string& msg_id, std::string& errmsg); - - void appendHeaders(Http::RequestHeaderMap& headers); - void sendHeartbeat(); - - bool forward_response_{false}; - std::string conn_id_; - std::string peer_domain_; - std::string receiver_service_name_; - std::string peer_receiver_host_; - int req_timeout_; - int rsp_timeout_; - int heartbeat_interval_; - Upstream::ClusterManager& cluster_manager_; - KusciaCommon::Decoder decoder_; - - std::map>, std::less<>> append_headers_; - - Http::RequestHeaderMapPtr headers_; - - Event::TimerPtr response_timer_; - TimeSource& time_source_; + bool attemptToDecodeMessage(Buffer::Instance& data); + + int32_t forwardMessage(const RequestMessagePb& message, std::string& errmsg); + int32_t forwardToApplication(Http::AsyncClient& client, const std::string& req_host, + Http::RequestMessagePtr& req_msg, const std::string& msg_id, + std::string& errmsg); + int32_t forwardToApiserver(Http::AsyncClient& client, const std::string& req_host, + Http::RequestMessagePtr& req_msg, const std::string& msg_id, + std::string& errmsg); + + void appendHeaders(Http::RequestHeaderMap& headers); + void sendHeartbeat(); + + bool forward_response_{false}; + std::string conn_id_; + std::string peer_domain_; + std::string receiver_service_name_; + std::string peer_receiver_host_; + int req_timeout_; + int rsp_timeout_; + int heartbeat_interval_; + Upstream::ClusterManager& cluster_manager_; + KusciaCommon::Decoder decoder_; + + std::map>, std::less<>> + append_headers_; + + Http::RequestHeaderMapPtr headers_; + + Event::TimerPtr response_timer_; + TimeSource& time_source_; }; - } // namespace KusciaPoller } // namespace HttpFilters } // namespace Extensions diff --git a/kuscia/source/filters/http/kuscia_receiver/conn.h b/kuscia/source/filters/http/kuscia_receiver/conn.h index 6e7dedb..0f70663 100644 --- a/kuscia/source/filters/http/kuscia_receiver/conn.h +++ b/kuscia/source/filters/http/kuscia_receiver/conn.h @@ -155,7 +155,7 @@ class TcpConn : public Logger::Loggable { } BufferPtr buffer; - if (data.body().size() > 0) { + if (data.body().size() > 0 || !headers) { buffer = std::make_shared(); buffer->add(data.body()); } @@ -173,10 +173,11 @@ class TcpConn : public Logger::Loggable { } }); - ENVOY_LOG( - trace, - "[TcpConn] [C{}][S{}] write response, status_code {} end_stream {} is_chunked {} index {}", - conn_id_, stream_id_, status_code, end_stream, is_chunked, index); + ENVOY_LOG(trace, + "[TcpConn] [C{}][S{}] write response, status_code {} end_stream {} is_chunked {} " + "index {} body size {}", + conn_id_, stream_id_, status_code, end_stream, is_chunked, index, + data.body().size()); return end_stream; } diff --git a/kuscia/source/filters/http/kuscia_receiver/receiver_filter.cc b/kuscia/source/filters/http/kuscia_receiver/receiver_filter.cc index e107ebf..8d42715 100644 --- a/kuscia/source/filters/http/kuscia_receiver/receiver_filter.cc +++ b/kuscia/source/filters/http/kuscia_receiver/receiver_filter.cc @@ -49,9 +49,12 @@ static void setReqHeader(RequestMessagePb& pb, Http::RequestHeaderMap& headers) pb.set_path(std::string(headers.getPathValue())); pb.set_method(std::string(headers.getMethodValue())); auto hs = pb.mutable_headers(); - headers.forEach([&](absl::string_view key, absl::string_view value) -> bool { - (*hs)[std::string(key)] = std::string(value); - return true; + + headers.iterate([&hs](const Http::HeaderEntry& e) -> Http::HeaderMap::Iterate { + auto key = std::string(e.key().getStringView()); + auto value = std::string(e.value().getStringView()); + (*hs)[key] = value; + return Envoy::Http::HeaderMap::Iterate::Continue; }); } @@ -70,6 +73,14 @@ Http::FilterHeadersStatus ReceiverFilter::decodeHeaders(Http::RequestHeaderMap& if (decoder_callbacks_->streamInfo().protocol() != Http::Protocol::Http11) { return Http::FilterHeadersStatus::Continue; } + if (isPassthroughTraffic(headers)) { + ENVOY_LOG(info, "host {} path {} method {} is passthrough traffic", headers.getHostValue(), + headers.getPathValue(), headers.getMethodValue()); + return Http::FilterHeadersStatus::Continue; + } else { + ENVOY_LOG(info, "host {} path {} method {} is kuscia traffic", headers.getHostValue(), + headers.getPathValue(), headers.getMethodValue()); + } if (isPollRequest(headers) || isForwardRequest(headers) || isForwardResponse(headers)) { if (event_type_ == ReceiverEventType::RECEIVER_EVENT_TYPE_DATA_SEND) { setReqHeader(request_pb_, headers); @@ -124,37 +135,43 @@ void ReceiverFilter::onDestroy() { } } +bool ReceiverFilter::isPassthroughTraffic(Http::RequestHeaderMap& headers) { + return headers.get(KusciaCommon::HeaderTransitFlag).empty(); +} + // poll request check // receiver.${peer}.svc/poll?timeout=xxx&service=xxx bool ReceiverFilter::isPollRequest(Http::RequestHeaderMap& headers) { - auto path = headers.getPathValue(); auto host = headers.getHostValue(); - auto source = headers.getByKey(KusciaCommon::HeaderKeyOriginSource).value_or(nullptr); - if (source == nullptr || host == nullptr || path == nullptr) { + auto path = headers.getPathValue(); + auto sourceValue = headers.get(KusciaCommon::HeaderKeyOriginSource); + if (host.empty() || path.empty() || sourceValue.empty() || sourceValue[0]->value().empty()) { return false; } + absl::string_view source = sourceValue[0]->value().getStringView(); + std::string group; if (!re2::RE2::PartialMatch(std::string(host), KusciaCommon::PollHostPattern, &group) || !absl::StartsWith(path, KusciaCommon::PollPathPrefix) || group != config_->selfNamespace()) { return false; } - auto query_params = Http::Utility::parseQueryString(path); - auto svc = query_params.find(KusciaCommon::ServiceParamKey); - if (svc == query_params.end() && svc->second.empty()) { + auto query_params = Http::Utility::QueryParamsMulti::parseQueryString(path); + auto svc = query_params.getFirstValue(KusciaCommon::ServiceParamKey); + if (!svc.has_value()) { return false; } event_type_ = ReceiverEventType::RECEIVER_EVENT_TYPE_CONNECT; rule_ = std::make_shared(); rule_->set_source(group); rule_->set_destination(std::string(source)); - rule_->set_service(std::string(svc->second)); + rule_->set_service(svc.value()); conn_uuid_ = random_.uuid(); - auto timeout = query_params.find(KusciaCommon::TimeoutParamKey); - if (timeout != query_params.end() && !timeout->second.empty()) { - timeout_sec_ = timeout2sec(timeout->second); + auto timeout = query_params.getFirstValue(KusciaCommon::TimeoutParamKey); + if (timeout.has_value()) { + timeout_sec_ = timeout2sec(timeout.value()); ENVOY_LOG(info, "[ReceiverFilter] poll request from {} to {} service {} timeout {}", group, - source, svc->second, timeout->second); + source, svc.value(), timeout.value()); } return true; } @@ -162,15 +179,25 @@ bool ReceiverFilter::isPollRequest(Http::RequestHeaderMap& headers) { // dst = ${svc}.dest-namespace.svc // src = src-namespace bool ReceiverFilter::isForwardRequest(Http::RequestHeaderMap& headers) { - auto source = headers.getByKey(KusciaCommon::HeaderKeyOriginSource).value_or(nullptr); - auto host = headers.getHostValue(); + auto sourceHeader = headers.get(KusciaCommon::HeaderKeyOriginSource); + if (sourceHeader.empty()) { + return false; + } + absl::string_view source = sourceHeader[0]->value().getStringView(); + + absl::string_view host; + auto hostValue = headers.getHostValue(); // rewrite bool rewrite = false; - if (absl::StartsWith(host, KusciaCommon::InternalClusterHost)) { - host = headers.getByKey(KusciaCommon::HeaderKeyKusciaHost).value_or(nullptr); + if (absl::StartsWith(hostValue, KusciaCommon::InternalClusterHost)) { + auto kusciaHost = headers.get(KusciaCommon::HeaderKeyKusciaHost); + if (!kusciaHost.empty()) { + host = kusciaHost[0]->value().getStringView(); + } rewrite = true; } - if (source != config_->selfNamespace() || host == nullptr) { + + if (std::string(source) != config_->selfNamespace() || host.empty()) { return false; } std::vector fields = absl::StrSplit(host, "."); @@ -196,7 +223,7 @@ bool ReceiverFilter::isForwardRequest(Http::RequestHeaderMap& headers) { bool ReceiverFilter::isForwardResponse(Http::RequestHeaderMap& headers) { auto path = headers.getPathValue(); auto host = headers.getHostValue(); - if (!absl::StartsWith(path, KusciaCommon::ReplyPathPrefix) || host == nullptr) { + if (!absl::StartsWith(path, KusciaCommon::ReplyPathPrefix) || host.empty()) { return false; } std::vector fields = absl::StrSplit(host, "."); @@ -206,13 +233,13 @@ bool ReceiverFilter::isForwardResponse(Http::RequestHeaderMap& headers) { if (fields[1] != config_->selfNamespace()) { return false; } - auto query_params = Http::Utility::parseQueryString(path); - auto request_id = query_params.find(KusciaCommon::RequestIdParamKey); - if (request_id == query_params.end() || request_id->second.empty()) { + auto query_params = Http::Utility::QueryParamsMulti::parseQueryString(path); + auto request_id = query_params.getFirstValue(KusciaCommon::RequestIdParamKey); + if (!request_id.has_value()) { return false; } event_type_ = ReceiverEventType::RECEIVER_EVENT_TYPE_DATA_RECV; - request_id_ = request_id->second; + request_id_ = request_id.value(); ENVOY_LOG(trace, "[ReceiverFilter] forward response request_id {}", request_id_); return true; } diff --git a/kuscia/source/filters/http/kuscia_receiver/receiver_filter.h b/kuscia/source/filters/http/kuscia_receiver/receiver_filter.h index 2b104d4..7cdd260 100644 --- a/kuscia/source/filters/http/kuscia_receiver/receiver_filter.h +++ b/kuscia/source/filters/http/kuscia_receiver/receiver_filter.h @@ -75,6 +75,7 @@ class ReceiverFilter : public Http::PassThroughFilter, Logger::Loggable(proto_config); - return [config](Http::FilterChainFactoryCallbacks & callbacks) -> void { - callbacks.addStreamDecoderFilter(std::make_shared(config)); - }; + const std::string&, Server::Configuration::FactoryContext&) { + TokenAuthConfigSharedPtr config = std::make_shared(proto_config); + return [config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamDecoderFilter(std::make_shared(config)); + }; } -Router::RouteSpecificFilterConfigConstSharedPtr TokenAuthConfigFactory::createRouteSpecificFilterConfigTyped( +Router::RouteSpecificFilterConfigConstSharedPtr +TokenAuthConfigFactory::createRouteSpecificFilterConfigTyped( const envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute& - proto_config, + proto_config, Server::Configuration::ServerFactoryContext&, ProtobufMessage::ValidationVisitor&) { - return std::make_shared(proto_config); + return std::make_shared(proto_config); } -REGISTER_FACTORY(TokenAuthConfigFactory, - Server::Configuration::NamedHttpFilterConfigFactory); +REGISTER_FACTORY(TokenAuthConfigFactory, Server::Configuration::NamedHttpFilterConfigFactory); } // namespace KusciaTokenAuth } // namespace HttpFilters diff --git a/kuscia/source/filters/http/kuscia_token_auth/config.h b/kuscia/source/filters/http/kuscia_token_auth/config.h index 01e3793..de3c861 100755 --- a/kuscia/source/filters/http/kuscia_token_auth/config.h +++ b/kuscia/source/filters/http/kuscia_token_auth/config.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -27,23 +26,23 @@ namespace Extensions { namespace HttpFilters { namespace KusciaTokenAuth { -class TokenAuthConfigFactory : public Extensions::HttpFilters::Common::FactoryBase < - envoy::extensions::filters::http::kuscia_token_auth::v3::TokenAuth, - envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute > { - public: - TokenAuthConfigFactory() : FactoryBase("envoy.filters.http.kuscia_token_auth") {} - - Http::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::http::kuscia_token_auth::v3::TokenAuth&, - const std::string&, - Server::Configuration::FactoryContext&) override; - - private: - Router::RouteSpecificFilterConfigConstSharedPtr createRouteSpecificFilterConfigTyped( - const envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute& - proto_config, - Server::Configuration::ServerFactoryContext& context, - ProtobufMessage::ValidationVisitor& validator) override; +class TokenAuthConfigFactory + : public Extensions::HttpFilters::Common::FactoryBase< + envoy::extensions::filters::http::kuscia_token_auth::v3::TokenAuth, + envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute> { +public: + TokenAuthConfigFactory() : FactoryBase("envoy.filters.http.kuscia_token_auth") {} + + Http::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::http::kuscia_token_auth::v3::TokenAuth&, + const std::string&, Server::Configuration::FactoryContext&) override; + +private: + Router::RouteSpecificFilterConfigConstSharedPtr createRouteSpecificFilterConfigTyped( + const envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute& + proto_config, + Server::Configuration::ServerFactoryContext& context, + ProtobufMessage::ValidationVisitor& validator) override; }; } // namespace KusciaTokenAuth diff --git a/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.cc b/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.cc index 8fa6f17..a1f4a01 100755 --- a/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.cc +++ b/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include "kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h" #include "source/common/common/empty_string.h" @@ -29,60 +28,65 @@ constexpr absl::string_view UnauthorizedBodyMessage = "unauthorized."; using KusciaHeader = Envoy::Extensions::HttpFilters::KusciaCommon::KusciaHeader; -Http::FilterHeadersStatus TokenAuthFilter::decodeHeaders(Http::RequestHeaderMap& headers, - bool) { - // Disable filter per route config if applies - if (decoder_callbacks_->route() != nullptr) { - const auto* per_route_config = - Http::Utility::resolveMostSpecificPerFilterConfig(decoder_callbacks_); - if (per_route_config != nullptr && per_route_config->disabled()) { - return Http::FilterHeadersStatus::Continue; - } +Http::FilterHeadersStatus TokenAuthFilter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { + // Disable filter per route config if applies + if (decoder_callbacks_->route() != nullptr) { + const auto* per_route_config = + Http::Utility::resolveMostSpecificPerFilterConfig( + decoder_callbacks_); + if (per_route_config != nullptr && per_route_config->disabled()) { + return Http::FilterHeadersStatus::Continue; } + } - auto source = KusciaHeader::getSource(headers).value_or(""); - auto token = headers.getByKey(KusciaCommon::HeaderKeyKusciaToken).value_or(""); - bool is_valid = config_->validateSource(source, token); - if (!is_valid) { - ENVOY_LOG(warn, "Check Kuscia Source Token fail, {}: {}, {}: {}", - KusciaCommon::HeaderKeyKusciaSource, source, - KusciaCommon::HeaderKeyKusciaToken, token); - sendUnauthorizedResponse(); - return Http::FilterHeadersStatus::StopIteration; - } + auto source = KusciaHeader::getSource(headers).value_or(""); + + absl::string_view token; + auto value = headers.get(KusciaCommon::HeaderKeyKusciaToken); + if (!value.empty()) { + token = value[0]->value().getStringView(); + } + bool is_valid = config_->validateSource(source, token); + if (!is_valid) { + ENVOY_LOG(warn, "Check Kuscia Source Token fail, {}: {}, {}: {}", + KusciaCommon::HeaderKeyKusciaSource, source, KusciaCommon::HeaderKeyKusciaToken, + token); + sendUnauthorizedResponse(); + return Http::FilterHeadersStatus::StopIteration; + } - return Http::FilterHeadersStatus::Continue; + return Http::FilterHeadersStatus::Continue; } void TokenAuthFilter::sendUnauthorizedResponse() { - decoder_callbacks_->sendLocalReply(Http::Code::Unauthorized, UnauthorizedBodyMessage, nullptr, - absl::nullopt, Envoy::EMPTY_STRING); + decoder_callbacks_->sendLocalReply(Http::Code::Unauthorized, UnauthorizedBodyMessage, nullptr, + absl::nullopt, Envoy::EMPTY_STRING); } TokenAuthConfig::TokenAuthConfig(const TokenAuthPbConfig& config) { - for (const auto& source_token : config.source_token_list()) { - std::vector tokens; - tokens.reserve(source_token.tokens_size()); - for (const auto& token : source_token.tokens()) { - tokens.emplace_back(token); - } - source_token_map_.emplace(source_token.source(), tokens); + for (const auto& source_token : config.source_token_list()) { + std::vector tokens; + tokens.reserve(source_token.tokens_size()); + for (const auto& token : source_token.tokens()) { + tokens.emplace_back(token); } + source_token_map_.emplace(source_token.source(), tokens); + } } bool TokenAuthConfig::validateSource(absl::string_view source, absl::string_view token) const { - static const std::string NoopToken = "noop"; + static const std::string NoopToken = "noop"; - auto iter = source_token_map_.find(source); - if (iter == source_token_map_.end()) { - return false; - } - for (const auto& disired_token : iter->second) { - if (token == disired_token || disired_token == NoopToken) { - return true; - } - } + auto iter = source_token_map_.find(source); + if (iter == source_token_map_.end()) { return false; + } + for (const auto& disired_token : iter->second) { + if (token == disired_token || disired_token == NoopToken) { + return true; + } + } + return false; } } // namespace KusciaTokenAuth diff --git a/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h b/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h index 0d5e753..6a81dc7 100755 --- a/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h +++ b/kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once #include @@ -34,47 +33,41 @@ using TokenAuthConfigSharedPtr = std::shared_ptr; using TokenAuthPbConfig = envoy::extensions::filters::http::kuscia_token_auth::v3::TokenAuth; class TokenAuthFilter : public Http::PassThroughDecoderFilter, - public Logger::Loggable { - public: - explicit TokenAuthFilter(TokenAuthConfigSharedPtr config) : - config_(std::move(config)) {} + public Logger::Loggable { +public: + explicit TokenAuthFilter(TokenAuthConfigSharedPtr config) : config_(std::move(config)) {} - Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, - bool) override; + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, bool) override; - private: - void sendUnauthorizedResponse(); +private: + void sendUnauthorizedResponse(); - TokenAuthConfigSharedPtr config_; + TokenAuthConfigSharedPtr config_; }; class TokenAuthConfig { - public: - explicit TokenAuthConfig(const TokenAuthPbConfig& config); +public: + explicit TokenAuthConfig(const TokenAuthPbConfig& config); - bool validateSource(absl::string_view source, absl::string_view token) const; + bool validateSource(absl::string_view source, absl::string_view token) const; - private: - std::map, std::less<>> source_token_map_; +private: + std::map, std::less<>> source_token_map_; }; class FilterConfigPerRoute : public Router::RouteSpecificFilterConfig { - public: - FilterConfigPerRoute( - const envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute& - config) - : disabled_(config.disabled()) {} - - bool disabled() const { - return disabled_; - } - - private: - bool disabled_; +public: + FilterConfigPerRoute( + const envoy::extensions::filters::http::kuscia_token_auth::v3::FilterConfigPerRoute& config) + : disabled_(config.disabled()) {} + + bool disabled() const { return disabled_; } + +private: + bool disabled_; }; } // namespace KusciaTokenAuth } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/test/filters/http/kuscia_crypt/crypt_filter_test.cc b/kuscia/test/filters/http/kuscia_crypt/crypt_filter_test.cc index 5cab6b2..776d458 100755 --- a/kuscia/test/filters/http/kuscia_crypt/crypt_filter_test.cc +++ b/kuscia/test/filters/http/kuscia_crypt/crypt_filter_test.cc @@ -1,29 +1,28 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/mocks/http/mocks.h" #include "source/common/stream_info/stream_info_impl.h" +#include "test/mocks/http/mocks.h" -#include "kuscia/source/filters/http/kuscia_crypt/crypt_filter.h" -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "kuscia/api/filters/http/kuscia_crypt/v3/crypt.pb.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "kuscia/source/filters/http/kuscia_crypt/crypt_filter.h" #include "kuscia/test/filters/http/test_common/header_checker.h" @@ -47,193 +46,191 @@ const std::string kEncryptIv(KusciaCommon::HeaderKeyEncryptIv.get()); const std::string kEncryptVersion(KusciaCommon::HeaderKeyEncryptVersion.get()); class CryptFilterTest : public testing::Test { - public: - CryptFilterTest() : filter_(setupConfig()), peer_filter_(peerSetupConfig()) { - filter_.setDecoderFilterCallbacks(decoder_callbacks_); - filter_.setEncoderFilterCallbacks(encoder_callbacks_); - peer_filter_.setDecoderFilterCallbacks(decoder_callbacks_peer_); - peer_filter_.setEncoderFilterCallbacks(encoder_callbacks_peer_); +public: + CryptFilterTest() : filter_(setupConfig()), peer_filter_(peerSetupConfig()) { + filter_.setDecoderFilterCallbacks(decoder_callbacks_); + filter_.setEncoderFilterCallbacks(encoder_callbacks_); + peer_filter_.setDecoderFilterCallbacks(decoder_callbacks_peer_); + peer_filter_.setEncoderFilterCallbacks(encoder_callbacks_peer_); + } + + CryptConfigSharedPtr setupConfig() { + CryptPbConfig proto_config; + proto_config.set_self_namespace("alice"); + + auto encrypt_config = proto_config.mutable_encrypt_rules()->Add(); + encrypt_config->set_source("alice"); + encrypt_config->set_destination("bob"); + encrypt_config->set_secret_key(kFirstKey); + encrypt_config->set_secret_key_version(kFirstKeyVersion); + encrypt_config->set_algorithm("AES"); + + return std::make_shared(proto_config); + } + + CryptConfigSharedPtr peerSetupConfig() { + CryptPbConfig proto_config; + proto_config.set_self_namespace("bob"); + + auto decrypt_config = proto_config.mutable_decrypt_rules()->Add(); + decrypt_config->set_source("alice"); + decrypt_config->set_destination("bob"); + decrypt_config->set_algorithm("AES"); + decrypt_config->set_secret_key(kFirstKey); + decrypt_config->set_secret_key_version(kFirstKeyVersion); + decrypt_config->set_reserve_key(kSecondKey); + decrypt_config->set_reserve_key_version(kSecondKeyVersion); + return std::make_shared(proto_config); + } + + void checkEncryptionParams(bool forward_encrypt, absl::string_view forward_key, + bool reverse_encrypt, absl::string_view reverse_key) { + EXPECT_EQ(forward_encrypt, static_cast(filter_.forward_crypter_)); + if (filter_.forward_crypter_) { + EXPECT_EQ(forward_key, filter_.forward_crypter_->secret_key_); } - - CryptConfigSharedPtr setupConfig() { - CryptPbConfig proto_config; - proto_config.set_self_namespace("alice"); - - auto encrypt_config = proto_config.mutable_encrypt_rules()->Add(); - encrypt_config->set_source("alice"); - encrypt_config->set_destination("bob"); - encrypt_config->set_secret_key(kFirstKey); - encrypt_config->set_secret_key_version(kFirstKeyVersion); - encrypt_config->set_algorithm("AES"); - - return std::make_shared(proto_config); + EXPECT_EQ(reverse_encrypt, static_cast(peer_filter_.reverse_crypter_)); + if (peer_filter_.reverse_crypter_) { + EXPECT_EQ(reverse_key, peer_filter_.reverse_crypter_->secret_key_); } - - CryptConfigSharedPtr peerSetupConfig() { - CryptPbConfig proto_config; - proto_config.set_self_namespace("bob"); - - auto decrypt_config = proto_config.mutable_decrypt_rules()->Add(); - decrypt_config->set_source("alice"); - decrypt_config->set_destination("bob"); - decrypt_config->set_algorithm("AES"); - decrypt_config->set_secret_key(kFirstKey); - decrypt_config->set_secret_key_version(kFirstKeyVersion); - decrypt_config->set_reserve_key(kSecondKey); - decrypt_config->set_reserve_key_version(kSecondKeyVersion); - return std::make_shared(proto_config); - } - - void checkEncryptionParams(bool forward_encrypt, absl::string_view forward_key, bool reverse_encrypt, - absl::string_view reverse_key) { - EXPECT_EQ(forward_encrypt, static_cast(filter_.forward_crypter_)); - if (filter_.forward_crypter_) { - EXPECT_EQ(forward_key, filter_.forward_crypter_->secret_key_); - - } - EXPECT_EQ(reverse_encrypt, static_cast(peer_filter_.reverse_crypter_)); - if (peer_filter_.reverse_crypter_) { - EXPECT_EQ(reverse_key, peer_filter_.reverse_crypter_->secret_key_); - } - } - - CryptFilter filter_; - CryptFilter peer_filter_; - NiceMock decoder_callbacks_; - NiceMock encoder_callbacks_; - NiceMock decoder_callbacks_peer_; - NiceMock encoder_callbacks_peer_; + } + + CryptFilter filter_; + CryptFilter peer_filter_; + NiceMock decoder_callbacks_; + NiceMock encoder_callbacks_; + NiceMock decoder_callbacks_peer_; + NiceMock encoder_callbacks_peer_; }; TEST_F(CryptFilterTest, RequestAndResponse) { - // request - Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.bob.svc"}}; - - // alice: decodeheader, create forwardcrypter - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); - KusciaHeaderChecker::checkRequestHeaders( - request_headers, - ExpectHeaders{{kEncryptVersion, kFirstKeyVersion}, {kEncryptIv, "", false}}); - // alice: decodedata, encrypt request body - std::string data("something plaintext"); - Envoy::Buffer::OwnedImpl request_body(data); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); - EXPECT_NE(data, request_body.toString()); - EXPECT_EQ(data.length(), request_body.length()); - - // bob: decode header, create reverse crypter - request_headers.addCopy(kKusciaHost, "hello.bob.svc"); - request_headers.addCopy(kOriginSource, "alice"); - EXPECT_EQ(Http::FilterHeadersStatus::Continue, peer_filter_.decodeHeaders(request_headers, false)); - checkEncryptionParams(true, kFirstKey, true, kFirstKey); - - // bob: decodedata, decrypt request body - EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(request_body, true)); - EXPECT_EQ(data, request_body.toString()); // after decrypt, data is restored - - // response - Http::TestResponseHeaderMapImpl response_headers{}; - response_headers.setStatus("200"); - EXPECT_EQ(Http::FilterHeadersStatus::Continue, peer_filter_.encodeHeaders(response_headers, false)); - KusciaHeaderChecker::checkResponseHeaders( - response_headers, - ExpectHeaders{{kOriginSource, "alice"}, {kEncryptVersion, kFirstKeyVersion}}); - - Envoy::Buffer::OwnedImpl response_body(data); - EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.encodeData(response_body, true)); - EXPECT_NE(data, response_body.toString()); - EXPECT_EQ(data.length(), response_body.length()); - - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.encodeHeaders(response_headers, false)); - checkEncryptionParams(true, kFirstKey, true, kFirstKey); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(response_body, true)); - EXPECT_EQ(data, response_body.toString()); // after decrypt, data is restored + // request + Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.bob.svc"}}; + + // alice: decodeheader, create forwardcrypter + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + KusciaHeaderChecker::checkRequestHeaders( + request_headers, + ExpectHeaders{{kEncryptVersion, kFirstKeyVersion}, {kEncryptIv, "", false}}); + // alice: decodedata, encrypt request body + std::string data("something plaintext"); + Envoy::Buffer::OwnedImpl request_body(data); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); + EXPECT_NE(data, request_body.toString()); + EXPECT_EQ(data.length() + AES_GCM_TAG_LENGTH, request_body.length()); + + // bob: decode header, create reverse crypter + request_headers.addCopy(kKusciaHost, "hello.bob.svc"); + request_headers.addCopy(kOriginSource, "alice"); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + peer_filter_.decodeHeaders(request_headers, false)); + checkEncryptionParams(true, kFirstKey, true, kFirstKey); + + // bob: decodedata, decrypt request body + EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(request_body, true)); + EXPECT_EQ(data, request_body.toString()); // after decrypt, data is restored + + // response + Http::TestResponseHeaderMapImpl response_headers{}; + response_headers.setStatus("200"); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + peer_filter_.encodeHeaders(response_headers, false)); + KusciaHeaderChecker::checkResponseHeaders( + response_headers, + ExpectHeaders{{kOriginSource, "alice"}, {kEncryptVersion, kFirstKeyVersion}}); + + Envoy::Buffer::OwnedImpl response_body(data); + EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.encodeData(response_body, true)); + EXPECT_NE(data, response_body.toString()); + EXPECT_EQ(data.length() + AES_GCM_TAG_LENGTH, response_body.length()); + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.encodeHeaders(response_headers, false)); + checkEncryptionParams(true, kFirstKey, true, kFirstKey); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(response_body, true)); + EXPECT_EQ(data, response_body.toString()); // after decrypt, data is restored } TEST_F(CryptFilterTest, BigRequestBody) { - Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.bob.svc"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); - KusciaHeaderChecker::checkRequestHeaders( - request_headers, - ExpectHeaders{{kEncryptVersion, kFirstKeyVersion}, {kEncryptIv, "", false}}); - - const uint32_t kBodySize = 10 * 4096 + 2000; // just a casual length - char data[kBodySize]; // no need to initialize - Envoy::Buffer::OwnedImpl request_body(data, kBodySize); - Envoy::Buffer::OwnedImpl encrypted_body; - srand(time(NULL)); - - auto remain_length = kBodySize; - auto len = rand() % 8192; - while (remain_length > 8192) { - Envoy::Buffer::OwnedImpl body(request_body.linearize(len), len); - request_body.drain(len); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(body, false)); - encrypted_body.add(body); - remain_length -= len; - len = rand() % 8192; - } - - if (remain_length > 0) { - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); - encrypted_body.add(request_body); - } - EXPECT_NE(Envoy::Buffer::OwnedImpl(data, kBodySize).toString(), encrypted_body.toString()); - EXPECT_EQ(kBodySize, encrypted_body.length()); - - request_headers.addCopy(kKusciaHost, "hello.bob.svc"); - request_headers.addCopy(kOriginSource, "alice"); - EXPECT_EQ(Http::FilterHeadersStatus::Continue, peer_filter_.decodeHeaders(request_headers, false)); - checkEncryptionParams(true, kFirstKey, true, kFirstKey); + Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.bob.svc"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + KusciaHeaderChecker::checkRequestHeaders( + request_headers, + ExpectHeaders{{kEncryptVersion, kFirstKeyVersion}, {kEncryptIv, "", false}}); + + const uint32_t kBodySize = 10 * 4096 + 2000; // just a casual length + char data[kBodySize]; // no need to initialize + Envoy::Buffer::OwnedImpl request_body(data, kBodySize); + Envoy::Buffer::OwnedImpl encrypted_body; + srand(time(NULL)); + + auto remain_length = kBodySize; + auto len = rand() % 8192; + while (remain_length > 8192) { + Envoy::Buffer::OwnedImpl body(request_body.linearize(len), len); + request_body.drain(len); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(body, false)); + encrypted_body.add(body); + remain_length -= len; + len = rand() % 8192; + } - remain_length = kBodySize; + if (remain_length > 0) { + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); + encrypted_body.add(request_body); + } + EXPECT_NE(Envoy::Buffer::OwnedImpl(data, kBodySize).toString(), encrypted_body.toString()); + EXPECT_EQ(kBodySize + AES_GCM_TAG_LENGTH, encrypted_body.length()); + + request_headers.addCopy(kKusciaHost, "hello.bob.svc"); + request_headers.addCopy(kOriginSource, "alice"); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + peer_filter_.decodeHeaders(request_headers, false)); + checkEncryptionParams(true, kFirstKey, true, kFirstKey); + + remain_length = kBodySize; + len = rand() % 8192; + Envoy::Buffer::OwnedImpl decrypted_body; + while (remain_length > 8192) { + Envoy::Buffer::OwnedImpl body(encrypted_body.linearize(len), len); + encrypted_body.drain(len); + EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(body, false)); + decrypted_body.add(body); + remain_length -= len; len = rand() % 8192; - Envoy::Buffer::OwnedImpl decrypted_body; - while (remain_length > 8192) { - Envoy::Buffer::OwnedImpl body(encrypted_body.linearize(len), len); - encrypted_body.drain(len); - EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(body, false)); - decrypted_body.add(body); - remain_length -= len; - len = rand() % 8192; - } - if (remain_length > 0) { - EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(encrypted_body, true)); - decrypted_body.add(encrypted_body); - } - EXPECT_EQ(Envoy::Buffer::OwnedImpl(data, kBodySize).toString(), - decrypted_body.toString()); // after decrypt, data is restored + } + if (remain_length > 0) { + EXPECT_EQ(Http::FilterDataStatus::Continue, peer_filter_.decodeData(encrypted_body, true)); + decrypted_body.add(encrypted_body); + } + EXPECT_EQ(Envoy::Buffer::OwnedImpl(data, kBodySize).toString(), + decrypted_body.toString()); // after decrypt, data is restored } TEST_F(CryptFilterTest, NoEncrypt) { - Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.joke.svc"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); - KusciaHeaderChecker::checkRequestHeaders(request_headers, - ExpectHeaders{{kEncryptVersion, ""}}); + Http::TestRequestHeaderMapImpl request_headers{{kHost, "hello.joke.svc"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + KusciaHeaderChecker::checkRequestHeaders(request_headers, ExpectHeaders{{kEncryptVersion, ""}}); - request_headers.addCopy(kOriginSource, "alice"); - request_headers.addCopy(kKusciaHost, "hello.bob.svc"); - EXPECT_EQ(Http::FilterHeadersStatus::Continue, - peer_filter_.decodeHeaders(request_headers, false)); + request_headers.addCopy(kOriginSource, "alice"); + request_headers.addCopy(kKusciaHost, "hello.bob.svc"); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + peer_filter_.decodeHeaders(request_headers, false)); - checkEncryptionParams(false, "", false, ""); + checkEncryptionParams(false, "", false, ""); } TEST_F(CryptFilterTest, UseReverseEnCryptKey) { - Http::TestRequestHeaderMapImpl request_headers{ - {kKusciaHost, "hello.bob.svc"}, - {kOriginSource, "alice"}, - {kEncryptVersion, kSecondKeyVersion}, - {kEncryptIv, "1"} - }; - - EXPECT_EQ(Http::FilterHeadersStatus::Continue, - peer_filter_.decodeHeaders(request_headers, false)); - checkEncryptionParams(false, "", true, kSecondKey); + Http::TestRequestHeaderMapImpl request_headers{{kKusciaHost, "hello.bob.svc"}, + {kOriginSource, "alice"}, + {kEncryptVersion, kSecondKeyVersion}, + {kEncryptIv, "1"}}; + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + peer_filter_.decodeHeaders(request_headers, false)); + checkEncryptionParams(false, "", true, kSecondKey); } -} // namespace KusciaBasic +} // namespace KusciaCrypt } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/test/filters/http/kuscia_gress/gress_filter_test.cc b/kuscia/test/filters/http/kuscia_gress/gress_filter_test.cc index 06546aa..67e2fdd 100755 --- a/kuscia/test/filters/http/kuscia_gress/gress_filter_test.cc +++ b/kuscia/test/filters/http/kuscia_gress/gress_filter_test.cc @@ -12,18 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. - #include #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/mocks/http/mocks.h" #include "source/common/stream_info/stream_info_impl.h" +#include "test/mocks/http/mocks.h" -#include "kuscia/source/filters/http/kuscia_gress/gress_filter.h" -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "kuscia/api/filters/http/kuscia_gress/v3/gress.pb.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "kuscia/source/filters/http/kuscia_gress/gress_filter.h" #include "kuscia/test/filters/http/test_common/header_checker.h" @@ -40,96 +39,94 @@ namespace { using namespace Envoy::Extensions::HttpFilters::KusciaTest; class GressFilterTest : public testing::Test { - public: - GressFilterTest() : filter_(setupConfig()), config_(setupConfig()) { - filter_.setDecoderFilterCallbacks(decoder_callbacks_); - } - - GressFilterConfigSharedPtr setupConfig() { - GressPbConfig proto_config; - proto_config.set_instance("foo"); - proto_config.set_self_namespace("alice"); - proto_config.set_add_origin_source(true); - proto_config.set_max_logging_body_size_per_reqeuest(5); - auto rh = proto_config.add_rewrite_host_config(); - rh->set_header("kuscia-Host"); - rh->set_rewrite_policy(RewriteHost::RewriteHostWithHeader); - return GressFilterConfigSharedPtr(new GressFilterConfig(proto_config)); - } - - void enableRecordBody () { - std::string kRecordBody(KusciaCommon::HeaderKeyRecordBody.get()); - Http::TestRequestHeaderMapImpl headers{{kRecordBody, "true"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, false)); - } - - GressFilter filter_; - GressFilterConfigSharedPtr config_; - NiceMock decoder_callbacks_; +public: + GressFilterTest() : filter_(setupConfig()), config_(setupConfig()) { + filter_.setDecoderFilterCallbacks(decoder_callbacks_); + } + + GressFilterConfigSharedPtr setupConfig() { + GressPbConfig proto_config; + proto_config.set_instance("foo"); + proto_config.set_self_namespace("alice"); + proto_config.set_add_origin_source(true); + proto_config.set_max_logging_body_size_per_reqeuest(5); + auto rh = proto_config.add_rewrite_host_config(); + rh->set_header("kuscia-Host"); + rh->set_rewrite_policy(RewriteHost::RewriteHostWithHeader); + return GressFilterConfigSharedPtr(new GressFilterConfig(proto_config)); + } + + void enableRecordBody() { + std::string kRecordBody(KusciaCommon::HeaderKeyRecordBody.get()); + Http::TestRequestHeaderMapImpl headers{{kRecordBody, "true"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, false)); + } + + GressFilter filter_; + GressFilterConfigSharedPtr config_; + NiceMock decoder_callbacks_; }; TEST_F(GressFilterTest, EmptyHost) { - Http::TestRequestHeaderMapImpl headers; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - KusciaHeaderChecker::checkRequestHeaders(headers, - ExpectHeaders{{kHost, ""}, - {kOrginSource, config_->selfNamespace()}}); + Http::TestRequestHeaderMapImpl headers; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + KusciaHeaderChecker::checkRequestHeaders( + headers, ExpectHeaders{{kHost, ""}, {kOrginSource, config_->selfNamespace()}}); } TEST_F(GressFilterTest, OtherHost) { - Http::TestRequestHeaderMapImpl headers{{kHost, "baidu.com"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - KusciaHeaderChecker::checkRequestHeaders(headers, - ExpectHeaders{{kHost, "baidu.com"}, - {kOrginSource, config_->selfNamespace()}}); + Http::TestRequestHeaderMapImpl headers{{kHost, "baidu.com"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + KusciaHeaderChecker::checkRequestHeaders( + headers, ExpectHeaders{{kHost, "baidu.com"}, {kOrginSource, config_->selfNamespace()}}); } TEST_F(GressFilterTest, RewriteHost) { - std::string source(KusciaCommon::HeaderKeyKusciaHost.get()); - std::string host = "baidu.com"; - Http::TestRequestHeaderMapImpl headers{{source, host}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - EXPECT_EQ("baidu.com", headers.getHostValue()); + std::string source(KusciaCommon::HeaderKeyKusciaHost.get()); + std::string host = "baidu.com"; + Http::TestRequestHeaderMapImpl headers{{source, host}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + EXPECT_EQ("baidu.com", headers.getHostValue()); } TEST_F(GressFilterTest, RecordBody) { - enableRecordBody(); - - Event::SimulatedTimeSystem test_time; - StreamInfo::StreamInfoImpl stream_info(Http::Protocol::Http2, test_time.timeSystem(), nullptr); - EXPECT_CALL(decoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(stream_info)); - EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); - - std::string new_data("he"); - Envoy::Buffer::OwnedImpl request_body(new_data); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, false)); - EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); - - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); - EXPECT_EQ(1, stream_info.dynamicMetadata().filter_metadata_size()); - EXPECT_EQ("hehe", stream_info.dynamicMetadata() - .filter_metadata() - .at("envoy.kuscia") - .fields() - .at("request_body") - .string_value()); + enableRecordBody(); + + Event::SimulatedTimeSystem test_time; + StreamInfo::StreamInfoImpl stream_info(Http::Protocol::Http2, test_time.timeSystem(), nullptr); + EXPECT_CALL(decoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(stream_info)); + EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); + + std::string new_data("he"); + Envoy::Buffer::OwnedImpl request_body(new_data); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, false)); + EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); + + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); + EXPECT_EQ(1, stream_info.dynamicMetadata().filter_metadata_size()); + EXPECT_EQ("hehe", stream_info.dynamicMetadata() + .filter_metadata() + .at("envoy.kuscia") + .fields() + .at("request_body") + .string_value()); } TEST_F(GressFilterTest, RecordBodySizeExceed) { - enableRecordBody(); + enableRecordBody(); - Event::SimulatedTimeSystem test_time; - StreamInfo::StreamInfoImpl stream_info(Http::Protocol::Http2, test_time.timeSystem(), nullptr); - EXPECT_CALL(decoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(stream_info)); - EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); + Event::SimulatedTimeSystem test_time; + StreamInfo::StreamInfoImpl stream_info(Http::Protocol::Http2, test_time.timeSystem(), nullptr); + EXPECT_CALL(decoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(stream_info)); + EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); - std::string new_data("hee"); - Envoy::Buffer::OwnedImpl request_body(new_data); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, false)); - EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); + std::string new_data("hee"); + Envoy::Buffer::OwnedImpl request_body(new_data); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, false)); + EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); - EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); - EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_body, true)); + EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size()); } } // namespace @@ -137,4 +134,3 @@ TEST_F(GressFilterTest, RecordBodySizeExceed) { } // namespace HttpFilters } // namespace Extensions } // namespace Envoy - diff --git a/kuscia/test/filters/http/kuscia_header_decorator/header_decorator_filter_test.cc b/kuscia/test/filters/http/kuscia_header_decorator/header_decorator_filter_test.cc index 0b283c4..84028ea 100755 --- a/kuscia/test/filters/http/kuscia_header_decorator/header_decorator_filter_test.cc +++ b/kuscia/test/filters/http/kuscia_header_decorator/header_decorator_filter_test.cc @@ -1,27 +1,26 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/mocks/http/mocks.h" #include "source/common/stream_info/stream_info_impl.h" +#include "test/mocks/http/mocks.h" -#include "kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h" -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "kuscia/api/filters/http/kuscia_header_decorator/v3/header_decorator.pb.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "kuscia/source/filters/http/kuscia_header_decorator/header_decorator_filter.h" #include "kuscia/test/filters/http/test_common/header_checker.h" @@ -35,43 +34,44 @@ using testing::_; using namespace Envoy::Extensions::HttpFilters::KusciaTest; class HeaderDecoratorFilterTest : public testing::Test { - public: - HeaderDecoratorFilterTest() : filter_(setupConfig()) {} +public: + HeaderDecoratorFilterTest() : filter_(setupConfig()) {} - HeaderDecoratorPbConfig setupConfig() { - HeaderDecoratorPbConfig proto_config; - auto source_header = proto_config.mutable_append_headers()->Add(); - source_header->set_source("alice"); - auto header1 = source_header->mutable_headers()->Add(); - header1->set_key("k1"); - header1->set_value("v1"); + HeaderDecoratorPbConfig setupConfig() { + HeaderDecoratorPbConfig proto_config; + auto source_header = proto_config.mutable_append_headers()->Add(); + source_header->set_source("alice"); + auto header1 = source_header->mutable_headers()->Add(); + header1->set_key("k1"); + header1->set_value("v1"); - auto header2 = source_header->mutable_headers()->Add(); - header2->set_key("k2"); - header2->set_value("v2"); + auto header2 = source_header->mutable_headers()->Add(); + header2->set_key("k2"); + header2->set_value("v2"); - return proto_config; - } + return proto_config; + } - HeaderDecoratorFilter filter_; + HeaderDecoratorFilter filter_; }; TEST_F(HeaderDecoratorFilterTest, append_with_empty_header) { - Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "alice"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - KusciaHeaderChecker::checkRequestHeaders(headers, ExpectHeaders{{"k1", "v1"}, {"k2", "v2"}}); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "alice"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + KusciaHeaderChecker::checkRequestHeaders(headers, ExpectHeaders{{"k1", "v1"}, {"k2", "v2"}}); } TEST_F(HeaderDecoratorFilterTest, append_with_unempty_header) { - Http::TestRequestHeaderMapImpl headers{{"k1", "v3"}, {kKusciaSource, "alice"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - KusciaHeaderChecker::checkRequestHeaders(headers, ExpectHeaders{{"k1", "v3"}, {"k2", "v2"}}); + Http::TestRequestHeaderMapImpl headers{{"k1", "v3"}, {kKusciaSource, "alice"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + KusciaHeaderChecker::checkRequestHeaders(headers, ExpectHeaders{{"k1", "v3"}, {"k2", "v2"}}); } TEST_F(HeaderDecoratorFilterTest, umatch_source) { - Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "bob"}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); - KusciaHeaderChecker::checkRequestHeaders(headers, ExpectHeaders{{"k1", "v1", false}, {"k2", "v2", false}}); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "bob"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + KusciaHeaderChecker::checkRequestHeaders( + headers, ExpectHeaders{{"k1", "v1", false}, {"k2", "v2", false}}); } } // namespace diff --git a/kuscia/test/filters/http/kuscia_token_auth/token_auth_filter_test.cc b/kuscia/test/filters/http/kuscia_token_auth/token_auth_filter_test.cc index 2c1de5c..539790c 100755 --- a/kuscia/test/filters/http/kuscia_token_auth/token_auth_filter_test.cc +++ b/kuscia/test/filters/http/kuscia_token_auth/token_auth_filter_test.cc @@ -1,27 +1,26 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "test/mocks/http/mocks.h" #include "source/common/stream_info/stream_info_impl.h" +#include "test/mocks/http/mocks.h" -#include "kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h" -#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" #include "kuscia/api/filters/http/kuscia_token_auth/v3/token_auth.pb.h" +#include "kuscia/source/filters/http/kuscia_common/kuscia_header.h" +#include "kuscia/source/filters/http/kuscia_token_auth/token_auth_filter.h" #include "kuscia/test/filters/http/test_common/header_checker.h" @@ -37,51 +36,42 @@ const std::string Token1("token1"); const std::string Token2("token2"); class TokenAuthFilterTest : public testing::Test { - public: - TokenAuthFilterTest() : filter_(setupConfig()) { - filter_.setDecoderFilterCallbacks(decoder_callbacks_); - } - - TokenAuthConfigSharedPtr setupConfig() { - TokenAuthPbConfig proto_config; - auto token_auth = proto_config.mutable_source_token_list()->Add(); - token_auth->set_source("alice"); - token_auth->add_tokens(Token1); - token_auth->add_tokens(Token2); - return std::make_shared(proto_config); - } - - TokenAuthFilter filter_; - NiceMock decoder_callbacks_; +public: + TokenAuthFilterTest() : filter_(setupConfig()) { + filter_.setDecoderFilterCallbacks(decoder_callbacks_); + } + + TokenAuthConfigSharedPtr setupConfig() { + TokenAuthPbConfig proto_config; + auto token_auth = proto_config.mutable_source_token_list()->Add(); + token_auth->set_source("alice"); + token_auth->add_tokens(Token1); + token_auth->add_tokens(Token2); + return std::make_shared(proto_config); + } + + TokenAuthFilter filter_; + NiceMock decoder_callbacks_; }; TEST_F(TokenAuthFilterTest, AuthSuccWithToken1) { - Http::TestRequestHeaderMapImpl headers{ - {kKusciaSource, "alice"}, - {kKusciaToken, Token1}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "alice"}, {kKusciaToken, Token1}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); } - TEST_F(TokenAuthFilterTest, AuthSuccWithToken2) { - Http::TestRequestHeaderMapImpl headers{ - {kKusciaSource, "alice"}, - {kKusciaToken, Token2}}; - EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "alice"}, {kKusciaToken, Token2}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers, true)); } TEST_F(TokenAuthFilterTest, AuthInvlidSource) { - Http::TestRequestHeaderMapImpl headers{ - {kKusciaSource, "bob"}, - {kKusciaToken, Token2}}; - EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_.decodeHeaders(headers, true)); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "bob"}, {kKusciaToken, Token2}}; + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_.decodeHeaders(headers, true)); } TEST_F(TokenAuthFilterTest, AuthInvlidToken) { - Http::TestRequestHeaderMapImpl headers{ - {kKusciaSource, "alice"}, - {kKusciaToken, "Token3"}}; - EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_.decodeHeaders(headers, true)); + Http::TestRequestHeaderMapImpl headers{{kKusciaSource, "alice"}, {kKusciaToken, "Token3"}}; + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_.decodeHeaders(headers, true)); } } // namespace diff --git a/kuscia/test/filters/http/test_common/header_checker.h b/kuscia/test/filters/http/test_common/header_checker.h index c5f50ec..41a0e57 100755 --- a/kuscia/test/filters/http/test_common/header_checker.h +++ b/kuscia/test/filters/http/test_common/header_checker.h @@ -1,18 +1,17 @@ // Copyright 2023 Ant Group Co., Ltd. -// +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include #include "gmock/gmock.h" @@ -39,35 +38,40 @@ const std::string kKusciaToken(KusciaCommon::HeaderKeyKusciaToken.get()); const std::string kOriginSource(KusciaCommon::HeaderKeyOriginSource.get()); struct ExpectHeader { - absl::string_view key; - absl::string_view value; - bool equal = true; - bool exist = true; + absl::string_view key; + absl::string_view value; + bool equal = true; + bool exist = true; }; using ExpectHeaders = std::vector; class KusciaHeaderChecker { - public: - static void checkRequestHeaders(const Http::TestRequestHeaderMapImpl& headers, - const ExpectHeaders& expects) { - for (const auto& iter : expects) { - auto result = headers.getByKey(Http::LowerCaseString(iter.key)).value_or(std::string()); - EXPECT_EQ(result == iter.value, iter.equal); - } +public: + static void checkRequestHeaders(const Http::TestRequestHeaderMapImpl& headers, + const ExpectHeaders& expects) { + for (const auto& iter : expects) { + auto result_sv = headers.get(Http::LowerCaseString(iter.key)); + absl::string_view result; + if (!result_sv.empty()) { + result = result_sv[0]->value().getStringView(); + } + EXPECT_EQ(result == iter.value, iter.equal); } + } - static void checkResponseHeaders(const Http::TestRequestHeaderMapImpl& headers, const ExpectHeaders& expects) { - for (const auto& iter : expects) { - auto result = headers.get(Http::LowerCaseString(iter.key)); - if (!iter.exist) { - EXPECT_TRUE(result.size() == 0); - } else { - EXPECT_TRUE(result.size() == 1 && result[0] != nullptr); - EXPECT_EQ(result[0]->value() == iter.value, iter.equal); - } - } + static void checkResponseHeaders(const Http::TestRequestHeaderMapImpl& headers, + const ExpectHeaders& expects) { + for (const auto& iter : expects) { + auto result = headers.get(Http::LowerCaseString(iter.key)); + if (!iter.exist) { + EXPECT_TRUE(result.size() == 0); + } else { + EXPECT_TRUE(result.size() == 1 && result[0] != nullptr); + EXPECT_EQ(result[0]->value() == iter.value, iter.equal); + } } + } }; } // namespace KusciaTest