diff --git a/pkg/provider/azure_loadbalancer.go b/pkg/provider/azure_loadbalancer.go index 5d25ce124e..f1e13b65dc 100644 --- a/pkg/provider/azure_loadbalancer.go +++ b/pkg/provider/azure_loadbalancer.go @@ -30,6 +30,7 @@ import ( "unicode" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/go-autorest/autorest/azure" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -143,7 +144,7 @@ func (az *Cloud) reconcileService(_ context.Context, clusterName string, service serviceIPs := lbIPsPrimaryPIPs klog.V(2).Infof("reconcileService: reconciling security group for service %q with IPs %q, wantLb = true", serviceName, serviceIPs) - if _, err := az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), serviceIPs, true /* wantLb */); err != nil { + if _, err := az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), fipConfigs, serviceIPs, true /* wantLb */); err != nil { klog.Errorf("reconcileSecurityGroup(%s) failed: %#v", serviceName, err) return nil, err } @@ -322,11 +323,14 @@ func (az *Cloud) EnsureLoadBalancerDeleted(_ context.Context, clusterName string } serviceIPsToCleanup := lbIPsPrimaryPIPs klog.V(2).Infof("EnsureLoadBalancerDeleted: reconciling security group for service %q with IPs %q, wantLb = false", serviceName, serviceIPsToCleanup) - var lbName string - if lb != nil { - lbName = ptr.Deref(lb.Name, "") + + _, _, fipConfigs, err := az.getServiceLoadBalancerStatus(service, lb) + if err != nil { + klog.Errorf("EnsureLoadBalancerDeleted: getServiceLoadBalancerStatus(%s) failed: %v", serviceName, err) + return err } - _, err = az.reconcileSecurityGroup(clusterName, service, lbName, serviceIPsToCleanup, false /* wantLb */) + + _, err = az.reconcileSecurityGroup(clusterName, service, ptr.Deref(lb.Name, ""), fipConfigs, serviceIPsToCleanup, false /* wantLb */) if err != nil { return err } @@ -2804,12 +2808,124 @@ func (az *Cloud) getExpectedHAModeLoadBalancingRuleProperties( return props, nil } +func (az *Cloud) listServicesByPublicIPs(pips []network.PublicIPAddress) ([]*v1.Service, error) { + logger := klog.Background().WithName("listServicesByPublicIPs") + var ( + svcNames []string + rv []*v1.Service + ) + + for _, pip := range pips { + if pip.ID == nil { // FIXME: it should not be nil + continue + } + resourceID, err := azure.ParseResourceID(*pip.ID) + if err != nil { // FIXME: it should never happen except for testing + continue + } + logger.V(4).Info("fetching public IPs", "pip-id", pip.ID) + pip, _, err := az.getPublicIPAddress(resourceID.ResourceGroup, resourceID.ResourceName, azcache.CacheReadTypeDefault) + + if err != nil { + return nil, err + } + + logger.V(4).Info("fetched public IP", "pip", pip) + v := getServiceFromPIPServiceTags(pip.Tags) + if v != "" { + parts := strings.Split(strings.TrimSpace(v), ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + svcNames = append(svcNames, p) + } + } + } + + for _, svcName := range svcNames { + parts := strings.Split(svcName, "/") + if len(parts) != 2 { + continue + } + ns, svcName := parts[0], parts[1] + + logger.Info("fetching service from lister", "ns", ns, "service-name", svcName) + svc, err := az.serviceLister.Services(ns).Get(svcName) + if err != nil { + return nil, fmt.Errorf("get service error: %w", err) + } + + rv = append(rv, svc) + } + + return rv, nil +} + +// listSharedIPPortMapping lists the shared IP port mapping for the service excluding the service itself. +// There are scenarios where multiple services share the same public IP, +// and in order to clean up the security rules, we need to know the port mapping of the shared IP. +func (az *Cloud) listSharedIPPortMapping(svc *v1.Service, publicIPs []network.PublicIPAddress) (map[network.SecurityRuleProtocol][]int32, error) { + var ( + logger = klog.Background().WithName("listSharedIPPortMapping").WithValues("service-name", svc.Name) + rv = make(map[network.SecurityRuleProtocol][]int32) + convertProtocol = func(protocol v1.Protocol) (network.SecurityRuleProtocol, error) { + switch protocol { + case v1.ProtocolTCP: + return network.SecurityRuleProtocolTCP, nil + case v1.ProtocolUDP: + return network.SecurityRuleProtocolUDP, nil + case v1.ProtocolSCTP: + return network.SecurityRuleProtocolAsterisk, nil + } + return "", fmt.Errorf("unsupported protocol %s", protocol) + } + ) + + services, err := az.listServicesByPublicIPs(publicIPs) + if err != nil { + logger.Error(err, "Failed to list services by public IPs") + return nil, err + } + + for _, s := range services { + logger.V(4).Info("iterating service", "service", s.Name, "namespace", s.Namespace) + if svc.Namespace == s.Namespace && svc.Name == s.Name { + // skip the service itself + continue + } + + for _, port := range s.Spec.Ports { + protocol, err := convertProtocol(port.Protocol) + if err != nil { + return nil, err + } + + var p int32 + if consts.IsK8sServiceDisableLoadBalancerFloatingIP(s) { + p = port.NodePort + } else { + p = port.Port + } + logger.V(4).Info("adding port mapping", "protocol", protocol, "port", p) + + rv[protocol] = append(rv[protocol], p) + } + } + + logger.V(4).Info("retain port mapping", "port-mapping", rv) + + return rv, nil +} + // This reconciles the Network Security Group similar to how the LB is reconciled. // This entails adding required, missing SecurityRules and removing stale rules. func (az *Cloud) reconcileSecurityGroup( clusterName string, service *v1.Service, - lbName string, lbIPs []string, - wantLb bool, + lbName string, + fipConfigs []*network.FrontendIPConfiguration, + lbIPs []string, wantLb bool, ) (*network.SecurityGroup, error) { logger := klog.Background().WithName("reconcileSecurityGroup"). WithValues("cluster", clusterName). @@ -2822,6 +2938,13 @@ func (az *Cloud) reconcileSecurityGroup( return nil, fmt.Errorf("no load balancer IP for setting up security rules for service %s", service.Name) } + var publicIPs []network.PublicIPAddress + for _, fipConfig := range fipConfigs { + if fipConfig.PublicIPAddress != nil { + publicIPs = append(publicIPs, *fipConfig.PublicIPAddress) + } + } + additionalIPs, err := loadbalancer.AdditionalPublicIPs(service) if wantLb && err != nil { return nil, fmt.Errorf("unable to get additional public IPs: %w", err) @@ -2900,7 +3023,16 @@ func (az *Cloud) reconcileSecurityGroup( dstIPv6Addresses := append(lbIPv6Addresses, backendIPv6Addresses...) dstIPv6Addresses = append(dstIPv6Addresses, additionalIPv6Addresses...) - accessControl.CleanSecurityGroup(dstIPv4Addresses, dstIPv6Addresses) + retainPortRanges, err := az.listSharedIPPortMapping(service, publicIPs) + if err != nil { + logger.Error(err, "Failed to list retain port ranges") + return nil, err + } + + if err := accessControl.CleanSecurityGroup(dstIPv4Addresses, dstIPv6Addresses, retainPortRanges); err != nil { + logger.Error(err, "Failed to clean security group") + return nil, err + } } if wantLb { diff --git a/pkg/provider/azure_loadbalancer_accesscontrol_test.go b/pkg/provider/azure_loadbalancer_accesscontrol_test.go index 3a21a23aa3..cc63267c23 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol_test.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol_test.go @@ -26,9 +26,14 @@ import ( "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes/fake" "k8s.io/utils/ptr" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/securitygroupclient/mocksecuritygroupclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer" @@ -82,7 +87,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - sg, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + sg, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) testutil.ExpectEqualInJSON(t, azureFx.SecurityGroup().Build(), sg) }) @@ -169,7 +174,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -255,7 +260,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) }) @@ -330,7 +335,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) }) @@ -412,7 +417,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -487,7 +492,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -567,7 +572,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -665,7 +670,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -745,7 +750,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -838,7 +843,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) }) @@ -906,7 +911,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -968,7 +973,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1034,7 +1039,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1118,7 +1123,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1185,7 +1190,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1264,7 +1269,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1361,7 +1366,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) }) @@ -1415,7 +1420,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"foo"}), - DestinationPortRanges: ptr.To([]string{"4000-6000"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule Priority: ptr.To(int32(4003)), }, @@ -1428,7 +1433,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -1522,7 +1527,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To([]string{"bar"}), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -1545,7 +1550,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1676,7 +1681,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1729,7 +1734,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"foo"}), - DestinationPortRanges: ptr.To([]string{"4000-6000"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule Priority: ptr.To(int32(4003)), }, @@ -1742,11 +1747,12 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, }, + azureFx. AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{allowedServiceTag}, k8sFx.Service().TCPPorts()). WithPriority(505). @@ -1836,7 +1842,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To([]string{"bar"}), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -1876,7 +1882,304 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) + assert.NoError(t, err) + }) + + t.Run("update rules - keep retain ports", func(t *testing.T) { + + sharedIPSvc := k8sFx.Service(). + WithNamespace("ns-02"). + WithName("svc-02"). + Build() + + sharedIPSvc.Spec.Ports = []v1.ServicePort{ + { + Name: "port-1", + Protocol: v1.ProtocolTCP, + Port: 18000, + NodePort: 48000, + }, + { + Name: "port2", + Protocol: v1.ProtocolTCP, + Port: 19000, + NodePort: 49000, + }, + } + var ( + ctrl = gomock.NewController(t) + az = GetTestCloud(ctrl) + publicIPAddressClient = az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + securityGroupClient = az.SecurityGroupsClient.(*mocksecuritygroupclient.MockInterface) + loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) + loadBalancer = azureFx.LoadBalancer().Build() + + allowedServiceTag = azureFx.ServiceTag() + allowedIPv4Ranges = fx.RandomIPv4PrefixStrings(3) + allowedIPv6Ranges = fx.RandomIPv6PrefixStrings(3) + allowedRanges = append(allowedIPv4Ranges, allowedIPv6Ranges...) + svc = k8sFx.Service(). + WithNamespace("ns-01"). + WithName("svc-01"). + WithAllowedServiceTags(allowedServiceTag). + WithAllowedIPRanges(allowedRanges...). + Build() + + kubeClient = fake.NewSimpleClientset(&sharedIPSvc, &svc) + informerFactory = informers.NewSharedInformerFactory(kubeClient, 0) + svcLister = informerFactory.Core().V1().Services().Lister() + + pip = fx.Azure().PublicIPAddress("pip1"). + WithTag(consts.ServiceTagKey, fmt.Sprintf("%s/%s,%s/%s", svc.Namespace, svc.Name, sharedIPSvc.Namespace, sharedIPSvc.Name)). + Build() + frontendIPConfigurations = []*network.FrontendIPConfiguration{ + { + FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &pip, + }, + }, + } + ) + defer ctrl.Finish() + + az.serviceLister = svcLister + informerFactory.Start(wait.NeverStop) + informerFactory.WaitForCacheSync(wait.NeverStop) + + var ( + noiseRules = azureFx.NoiseSecurityRules(10) + staleRules = []network.SecurityRule{ + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{allowedServiceTag}, []int32{8000}). + WithPriority(4000). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). // Should remove the rule + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{allowedServiceTag}, []int32{6000, 3000}). + WithPriority(4001). + WithDestination(append(azureFx.LoadBalancer().IPv4Addresses(), "foo", "bar")...). // Should keep foo and bar but clean the rest + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv6, allowedIPv6Ranges, []int32{9000}). + WithPriority(4002). + WithDestination(append(azureFx.LoadBalancer().IPv6Addresses(), "baz")...). // Should keep baz but clean the rest + Build(), + + { + Name: ptr.To("foo"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"foo"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), + DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule + Priority: ptr.To(int32(4003)), + }, + }, + { + Name: ptr.To("bar"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"bar"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), + DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest + Priority: ptr.To(int32(4004)), + }, + }, + + { + Name: ptr.To("baz"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"baz"}), + DestinationPortRanges: ptr.To([]string{"18000", "19000"}), + DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "baz")), // Should keep all since the ports are retained + Priority: ptr.To(int32(4005)), + }, + }, + + { + Name: ptr.To("quo"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"quo"}), + DestinationPortRanges: ptr.To([]string{"18000", "19000", "20000"}), + DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "quo")), // Should split the rules + Priority: ptr.To(int32(4006)), + }, + }, + } + targetRules = []network.SecurityRule{ + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{allowedServiceTag}, k8sFx.Service().TCPPorts()). + WithPriority(505). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, allowedIPv4Ranges, k8sFx.Service().TCPPorts()). + WithPriority(507). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv6, []string{allowedServiceTag}, k8sFx.Service().TCPPorts()). + WithPriority(509). + WithDestination(azureFx.LoadBalancer().IPv6Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv6, allowedIPv6Ranges, k8sFx.Service().TCPPorts()). + WithPriority(520). + WithDestination(azureFx.LoadBalancer().IPv6Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, []string{allowedServiceTag}, k8sFx.Service().UDPPorts()). + WithPriority(530). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv4, allowedIPv4Ranges, k8sFx.Service().UDPPorts()). + WithPriority(607). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv6, []string{allowedServiceTag}, k8sFx.Service().UDPPorts()). + WithPriority(709). + WithDestination(azureFx.LoadBalancer().IPv6Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv6, allowedIPv6Ranges, k8sFx.Service().UDPPorts()). + WithPriority(3000). + WithDestination(azureFx.LoadBalancer().IPv6Addresses()...). + Build(), + } + ) + + publicIPAddressClient.EXPECT(). + List(gomock.Any(), gomock.Any()). + Return([]network.PublicIPAddress{pip}, nil). + Times(1) + + securityGroup := azureFx.SecurityGroup().WithRules( + append(append(noiseRules, targetRules...), staleRules...), + ).Build() + + securityGroupClient.EXPECT(). + Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any()). + Return(securityGroup, nil). + Times(1) + securityGroupClient.EXPECT(). + CreateOrUpdate(gomock.Any(), az.ResourceGroup, az.SecurityGroupName, gomock.Any(), gomock.Any()). + DoAndReturn(func( + ctx context.Context, + resourceGroupName, securityGroupName string, + properties network.SecurityGroup, + etag string, + ) *retry.Error { + rules := append(append(noiseRules, targetRules...), + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{allowedServiceTag}, []int32{6000, 3000}). + WithPriority(4001). + WithDestination("foo", "bar"). // Should keep foo and bar but clean the rest + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolUDP, iputil.IPv6, allowedIPv6Ranges, []int32{9000}). + WithPriority(4002). + WithDestination("baz"). // Should keep baz but clean the rest + Build(), + + network.SecurityRule{ + Name: ptr.To("bar"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"bar"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), + DestinationAddressPrefixes: ptr.To([]string{"bar"}), // Should keep bar but clean the rest + Priority: ptr.To(int32(4004)), + }, + }, + + network.SecurityRule{ + Name: ptr.To("baz"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"baz"}), + DestinationPortRanges: ptr.To([]string{"18000", "19000"}), + DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "baz")), // Should keep all since the ports are retained + Priority: ptr.To(int32(4005)), + }, + }, + + network.SecurityRule{ + Name: ptr.To("quo"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourcePortRange: ptr.To("*"), + SourceAddressPrefixes: ptr.To([]string{"quo"}), + DestinationPortRanges: ptr.To([]string{"18000", "19000", "20000"}), + DestinationAddressPrefixes: ptr.To([]string{"quo"}), // Should split the rules + Priority: ptr.To(int32(4006)), + }, + }, + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv4, []string{"quo"}, []int32{18000, 19000}). + WithPriority(500). + WithDestination(azureFx.LoadBalancer().IPv4Addresses()...). + Build(), + + azureFx. + AllowSecurityRule(network.SecurityRuleProtocolTCP, iputil.IPv6, []string{"quo"}, []int32{18000, 19000}). + WithPriority(501). + WithDestination(azureFx.LoadBalancer().IPv6Addresses()...). + Build(), + ) + + testutil.ExpectExactSecurityRules(t, &properties, rules) + + return nil + }).Times(1) + loadBalancerClient.EXPECT(). + Get(gomock.Any(), az.ResourceGroup, *loadBalancer.Name, gomock.Any()). + Return(loadBalancer, nil). + Times(1) + loadBalancerBackendPool.EXPECT(). + GetBackendPrivateIPs(ClusterName, &svc, &loadBalancer). + Return( + []string{}, []string{}, + ). + Times(1) + + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, frontendIPConfigurations, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.NoError(t, err) }) @@ -1947,7 +2250,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"foo"}), - DestinationPortRanges: ptr.To([]string{"4000-6000"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule Priority: ptr.To(int32(4003)), }, @@ -1960,7 +2263,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -2024,7 +2327,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To([]string{"bar"}), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -2053,7 +2356,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), false) // deleting + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), false) // deleting assert.NoError(t, err) }) @@ -2131,7 +2434,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"foo"}), - DestinationPortRanges: ptr.To([]string{"4000-6000"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule Priority: ptr.To(int32(4003)), }, @@ -2144,7 +2447,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -2208,7 +2511,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To([]string{"bar"}), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -2237,7 +2540,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), false) // deleting + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), false) // deleting assert.NoError(t, err) }) @@ -2307,7 +2610,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"foo"}), - DestinationPortRanges: ptr.To([]string{"4000-6000"}), + DestinationPortRanges: ptr.To([]string{"4000", "6000"}), DestinationAddressPrefixes: ptr.To(azureFx.LoadBalancer().Addresses()), // Should remove the rule Priority: ptr.To(int32(4003)), }, @@ -2320,7 +2623,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Direction: network.SecurityRuleDirectionInbound, SourcePortRange: ptr.To("*"), SourceAddressPrefixes: ptr.To([]string{"bar"}), - DestinationPortRanges: ptr.To([]string{"5000-6000"}), + DestinationPortRanges: ptr.To([]string{"5000", "6000"}), DestinationAddressPrefixes: ptr.To(append(azureFx.LoadBalancer().Addresses(), "bar")), // Should keep bar but clean the rest Priority: ptr.To(int32(4004)), }, @@ -2360,7 +2663,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Return(loadBalancer, &retry.Error{HTTPStatusCode: http.StatusNotFound}). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, false) // deleting + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, nil, false) // deleting assert.NoError(t, err) }) @@ -2389,7 +2692,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Return(securityGroup, nil). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) assert.ErrorIs(t, err, loadbalancer.ErrSetBothLoadBalancerSourceRangesAndAllowedIPRanges) }) @@ -2414,7 +2717,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Return(securityGroup, expectedErr). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) assert.ErrorIs(t, err, expectedErr.RawError) }) @@ -2444,7 +2747,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Return(loadBalancer, expectedErr). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) assert.ErrorIs(t, err, expectedErr.RawError) }) @@ -2487,7 +2790,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) assert.ErrorIs(t, err, expectedErr.RawError) }) @@ -2522,7 +2825,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ). Times(1) - _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) + _, err := az.reconcileSecurityGroup(ClusterName, &svc, *loadBalancer.Name, nil, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) }) }) diff --git a/pkg/provider/loadbalancer/accesscontrol.go b/pkg/provider/loadbalancer/accesscontrol.go index fcdb92c92d..5879c09455 100644 --- a/pkg/provider/loadbalancer/accesscontrol.go +++ b/pkg/provider/loadbalancer/accesscontrol.go @@ -265,21 +265,41 @@ func (ac *AccessControl) PatchSecurityGroup(dstIPv4Addresses, dstIPv6Addresses [ } // CleanSecurityGroup removes the given IP addresses from the SecurityGroup. -func (ac *AccessControl) CleanSecurityGroup(dstIPv4Addresses, dstIPv6Addresses []netip.Addr) { +func (ac *AccessControl) CleanSecurityGroup( + dstIPv4Addresses, dstIPv6Addresses []netip.Addr, + retainPortRanges map[network.SecurityRuleProtocol][]int32, +) error { logger := ac.logger.WithName("CleanSecurityGroup"). WithValues("num-dst-ipv4-addresses", len(dstIPv4Addresses)). WithValues("num-dst-ipv6-addresses", len(dstIPv6Addresses)) logger.V(10).Info("Start cleaning") var ( - prefixes = fnutil.Map(func(addr netip.Addr) string { - return addr.String() - }, append(dstIPv4Addresses, dstIPv6Addresses...)) + ipv4Prefixes = fnutil.Map(func(addr netip.Addr) string { return addr.String() }, dstIPv4Addresses) + ipv6Prefixes = fnutil.Map(func(addr netip.Addr) string { return addr.String() }, dstIPv6Addresses) ) - ac.sgHelper.RemoveDestinationPrefixesFromRules(prefixes) + protocols := []network.SecurityRuleProtocol{ + network.SecurityRuleProtocolTCP, + network.SecurityRuleProtocolUDP, + network.SecurityRuleProtocolAsterisk, + } + + for _, protocol := range protocols { + retainDstPorts := retainPortRanges[protocol] + if err := ac.sgHelper.RemoveDestinationFromRules(protocol, ipv4Prefixes, retainDstPorts); err != nil { + logger.Error(err, "Failed to remove IPv4 destination from rules") + return err + } + + if err := ac.sgHelper.RemoveDestinationFromRules(protocol, ipv6Prefixes, retainDstPorts); err != nil { + logger.Error(err, "Failed to remove IPv6 destination from rules") + return err + } + } logger.V(10).Info("Completed cleaning") + return nil } // SecurityGroup returns the SecurityGroup object with patched rules and indicates if the rules had been changed. diff --git a/pkg/provider/loadbalancer/accesscontrol_test.go b/pkg/provider/loadbalancer/accesscontrol_test.go index 6c7960caf6..879d7a36f5 100644 --- a/pkg/provider/loadbalancer/accesscontrol_test.go +++ b/pkg/provider/loadbalancer/accesscontrol_test.go @@ -1090,7 +1090,7 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { ) assert.NoError(t, err) - ac.CleanSecurityGroup(fx.RandomIPv4Addresses(2), fx.RandomIPv6Addresses(2)) + assert.NoError(t, ac.CleanSecurityGroup(fx.RandomIPv4Addresses(2), fx.RandomIPv6Addresses(2), make(map[network.SecurityRuleProtocol][]int32))) _, updated, err := ac.SecurityGroup() assert.NoError(t, err) assert.False(t, updated) @@ -1137,7 +1137,7 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { ) assert.NoError(t, err) - ac.CleanSecurityGroup(dstIPv4Addresses, nil) + assert.NoError(t, ac.CleanSecurityGroup(dstIPv4Addresses, nil, make(map[network.SecurityRuleProtocol][]int32))) _, updated, err := ac.SecurityGroup() assert.NoError(t, err) assert.False(t, updated) @@ -1197,7 +1197,7 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { ) assert.NoError(t, err) - ac.CleanSecurityGroup(dstIPv4Addresses, nil) + assert.NoError(t, ac.CleanSecurityGroup(dstIPv4Addresses, nil, make(map[network.SecurityRuleProtocol][]int32))) outputSG, updated, err := ac.SecurityGroup() assert.NoError(t, err) assert.True(t, updated) @@ -1312,7 +1312,7 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { ) assert.NoError(t, err) - ac.CleanSecurityGroup(dstIPv4Addresses, nil) + assert.NoError(t, ac.CleanSecurityGroup(dstIPv4Addresses, nil, make(map[network.SecurityRuleProtocol][]int32))) outputSG, updated, err := ac.SecurityGroup() assert.NoError(t, err) @@ -1346,4 +1346,121 @@ func TestAccessControl_CleanSecurityGroup(t *testing.T) { }, }, outputSG.SecurityRules) }) + + t.Run("it should split rules if retainPorts is set", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"80", "443"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"20.0.0.1", "192.168.0.1", "192.168.0.2", "20.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"53", "54", "55", "56"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolAsterisk, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"*"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + } + + sg = azureFx.SecurityGroup().WithRules(rules).Build() + dstIPv4Addresses = []netip.Addr{ + netip.MustParseAddr("192.168.0.1"), + netip.MustParseAddr("192.168.0.2"), + } + svc = fx.Kubernetes().Service().Build() + ac, err = NewAccessControl(&svc, &sg) + ) + assert.NoError(t, err) + + assert.NoError(t, ac.CleanSecurityGroup(dstIPv4Addresses, nil, map[network.SecurityRuleProtocol][]int32{ + network.SecurityRuleProtocolUDP: {56, 53}, + })) + outputSG, updated, err := ac.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + + testutil.ExpectEqualInJSON(t, []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"80", "443"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"20.0.0.1", "20.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"53", "54", "55", "56"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolAsterisk, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"*"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + { + Name: ptr.To("k8s-azure-lb_allow_IPv4_648b18e18a92d1a4b415033da37c79a5"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"192.168.0.1", "192.168.0.2"}), + DestinationPortRanges: ptr.To([]string{"53", "56"}), // 53 and 56 are retained + Priority: ptr.To(int32(503)), + }, + }, + }, outputSG.SecurityRules) + }) } diff --git a/pkg/provider/loadbalancer/fnutil/slice.go b/pkg/provider/loadbalancer/fnutil/slice.go index 210fe1d97e..ad94878466 100644 --- a/pkg/provider/loadbalancer/fnutil/slice.go +++ b/pkg/provider/loadbalancer/fnutil/slice.go @@ -42,3 +42,22 @@ func IsAll[T any](f func(T) bool, xs []T) bool { } return true } + +func IndexSet[T comparable](xs []T) map[T]bool { + rv := make(map[T]bool, len(xs)) + for _, x := range xs { + rv[x] = true + } + return rv +} + +func Intersection[T comparable](xs, ys []T) []T { + ysSet := IndexSet(ys) + var rv []T + for _, x := range xs { + if ysSet[x] { + rv = append(rv, x) + } + } + return rv +} diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/loadbalancer/securitygroup/securitygroup.go index 659944f989..2144b0f31c 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup.go @@ -18,13 +18,11 @@ package securitygroup import ( "bytes" - "crypto/md5" //nolint:gosec "encoding/json" "fmt" "net/netip" "sort" "strconv" - "strings" "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" "k8s.io/klog/v2" @@ -303,6 +301,72 @@ func (helper *RuleHelper) AddRuleForDenyAll(dstAddresses []netip.Addr) error { return nil } +// RemoveDestinationFromRules removes the given destination addresses from rules that match the given protocol and ports is in the retainDstPorts list. +// It may add a new rule if the original rule needs to be split. +func (helper *RuleHelper) RemoveDestinationFromRules( + protocol network.SecurityRuleProtocol, + dstPrefixes []string, + retainDstPorts []int32, +) error { + logger := helper.logger.WithName("RemoveDestinationFromRules").WithValues("protocol", protocol, "num-dst-prefixes", len(dstPrefixes)) + logger.V(10).Info("Cleaning destination from SecurityGroup") + + for _, rule := range helper.rules { + if rule.Protocol != protocol { + continue + } + + if err := helper.removeDestinationFromRule(rule, dstPrefixes, retainDstPorts); err != nil { + logger.Error(err, "Failed to remove destination from rule", "rule-name", *rule.Name) + } + } + + return nil +} + +func (helper *RuleHelper) removeDestinationFromRule(rule *network.SecurityRule, prefixes []string, retainDstPorts []int32) error { + logger := helper.logger.WithName("removeDestinationFromRule"). + WithValues("security-rule-name", rule.Name) + currentPorts, err := ListDestinationPortRanges(rule) + if err != nil { + // Skip the rule with invalid destination port ranges. + // NOTE: cloud-provider would not create allow rules with `*` or `4000-5000` as destination port ranges. + logger.Info("Skip because it contains `*` or port-ranges as destination port ranges.") + return nil + } + + var ( + prefixIndex = fnutil.IndexSet(prefixes) // Used to check whether the prefix should be removed. + currentPrefixes = ListDestinationPrefixes(rule) + + expectedPrefixes = fnutil.RemoveIf(func(p string) bool { return prefixIndex[p] }, currentPrefixes) // The prefixes to keep. + targetPrefixes = fnutil.Intersection(currentPrefixes, prefixes) // The prefixes to remove. + expectedPorts = fnutil.Intersection(currentPorts, retainDstPorts) // The ports to keep. + ) + + if len(targetPrefixes) == 0 || len(currentPorts) == len(expectedPorts) { + return nil + } + + // Update the prefixes + rule.DestinationAddressPrefix = nil + rule.DestinationAddressPrefixes = ptr.To(NormalizeSecurityRuleAddressPrefixes(expectedPrefixes)) + + if len(expectedPorts) == 0 { + // No additional ports are expected, no more actions are needed. + return nil + } + + // There are additional ports are expected, need to create a new rule for them. + addr, err := netip.ParseAddr(prefixes[0]) + if err != nil { + logger.Error(err, "Failed to parse dst IP address", "dst-ip", prefixes[0]) + return fmt.Errorf("parse prefix as IP address %q: %w", prefixes[0], err) + } + ipFamily := iputil.FamilyOfAddr(addr) + return helper.addAllowRule(rule.Protocol, ipFamily, ListSourcePrefixes(rule), prefixes, expectedPorts) +} + // RemoveDestinationPrefixesFromRules removes the given destination addresses from all rules. func (helper *RuleHelper) RemoveDestinationPrefixesFromRules(prefixes []string) { helper.logger.V(10).Info("Cleaning destination address prefixes from SecurityGroup", "num-dst-prefixes", len(prefixes)) @@ -395,28 +459,6 @@ func (helper *RuleHelper) SecurityGroup() (*network.SecurityGroup, bool, error) return rv, updated, nil } -// NormalizeSecurityRuleAddressPrefixes normalizes the given rule address prefixes. -func NormalizeSecurityRuleAddressPrefixes(vs []string) []string { - // Remove redundant addresses. - indexes := make(map[string]bool, len(vs)) - for _, v := range vs { - indexes[v] = true - } - rv := make([]string, 0, len(indexes)) - for k := range indexes { - rv = append(rv, k) - } - sort.Strings(rv) - return rv -} - -// NormalizeDestinationPortRanges normalizes the given destination port ranges. -func NormalizeDestinationPortRanges(dstPorts []int32) []string { - rv := fnutil.Map(func(p int32) string { return strconv.FormatInt(int64(p), 10) }, dstPorts) - sort.Strings(rv) - return rv -} - // makeSecurityGroupSnapshot returns a byte array as the snapshot of the given SecurityGroup. // It's used to check if the SecurityGroup had been changed. func makeSecurityGroupSnapshot(sg *network.SecurityGroup) []byte { @@ -426,37 +468,3 @@ func makeSecurityGroupSnapshot(sg *network.SecurityGroup) []byte { snapshot, _ := json.Marshal(sg) return snapshot } - -// GenerateAllowSecurityRuleName returns the AllowInbound rule name based on the given rule properties. -func GenerateAllowSecurityRuleName( - protocol network.SecurityRuleProtocol, - ipFamily iputil.Family, - srcPrefixes []string, - dstPorts []int32, -) string { - var ruleID string - { - dstPortRanges := fnutil.Map(func(p int32) string { return strconv.FormatInt(int64(p), 10) }, dstPorts) - // Generate rule ID from protocol, source prefixes and destination port ranges. - sort.Strings(srcPrefixes) - sort.Strings(dstPortRanges) - - v := strings.Join([]string{ - string(protocol), - strings.Join(srcPrefixes, ","), - strings.Join(dstPortRanges, ","), - }, "_") - - h := md5.New() //nolint:gosec - h.Write([]byte(v)) - - ruleID = fmt.Sprintf("%x", h.Sum(nil)) - } - - return strings.Join([]string{SecurityRuleNamePrefix, "allow", string(ipFamily), ruleID}, SecurityRuleNameSep) -} - -// GenerateDenyAllSecurityRuleName returns the DenyInbound rule name based on the given rule properties. -func GenerateDenyAllSecurityRuleName(ipFamily iputil.Family) string { - return strings.Join([]string{SecurityRuleNamePrefix, "deny-all", string(ipFamily)}, SecurityRuleNameSep) -} diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go index c51d8ed9b6..9031630503 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup_test.go @@ -1031,8 +1031,411 @@ func TestSecurityGroupHelper_AddRuleForDenyAll(t *testing.T) { }) } +func TestRuleHelper_RemoveDestinationFromRules(t *testing.T) { + fx := fixture.NewFixture() + + t.Run("it should not patch rules if no rules exist", func(t *testing.T) { + var ( + sg = fx.Azure().SecurityGroup().Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = fnutil.Map(func(p netip.Addr) string { + return p.String() + }, fx.RandomIPv4Addresses(2)) + ) + err := helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{}) + assert.NoError(t, err) + + _, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.False(t, updated) + }) + + t.Run("it should not patch rules if no rules match", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"20.0.0.1", "20.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"53"}), + Priority: ptr.To(int32(501)), + }, + }, + } + + sg = fx.Azure().SecurityGroup().WithRules(rules).Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = []string{ + "192.168.0.1", + "192.168.0.2", + } + ) + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{})) + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolUDP, dstAddresses, []int32{})) + + _, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.False(t, updated) + }) + + t.Run("it should patch the matched rules", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"20.0.0.1", "192.168.0.1", "192.168.0.2", "20.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"53"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolAsterisk, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + } + + sg = fx.Azure().SecurityGroup().WithRules(rules).Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = []string{ + "192.168.0.1", + "192.168.0.2", + } + ) + + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{})) + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolUDP, dstAddresses, []int32{})) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + testutil.ExpectEqualInJSON(t, []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"20.0.0.1", "20.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"53"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolAsterisk, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + }, outputSG.SecurityRules) + }) + + t.Run("it should remove the matched rules if no destination addresses left", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_baz", "src_quo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"192.168.0.1", "192.168.0.2"}), + DestinationPortRanges: ptr.To([]string{"53"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + { + Name: ptr.To("test-rule-3"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefix: ptr.To("192.168.0.1"), + DestinationPortRanges: ptr.To([]string{"8000"}), + Priority: ptr.To(int32(2000)), + }, + }, + { + Name: ptr.To("test-rule-4"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{}), + DestinationAddressPrefix: ptr.To("192.168.0.1"), + DestinationPortRanges: ptr.To([]string{"8000"}), + Priority: ptr.To(int32(2000)), + }, + }, + } + + sg = fx.Azure().SecurityGroup().WithRules(rules).Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = []string{ + "192.168.0.1", + "192.168.0.2", + } + ) + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{})) + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolUDP, dstAddresses, []int32{})) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + testutil.ExpectEqualInJSON(t, []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefix: ptr.To("*"), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"8.8.8.8"}), + DestinationPortRanges: ptr.To([]string{"5000"}), + Priority: ptr.To(int32(502)), + }, + }, + }, outputSG.SecurityRules) + }) + + t.Run("it should retain the port ranges if specified - all ports retained - nothing changed", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"src_foo", "src_bar"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + } + + sg = fx.Azure().SecurityGroup().WithRules(rules).Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = []string{ + "10.0.0.1", + "10.0.0.2", + } + ) + + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{443, 80})) + + _, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.False(t, updated) + }) + + t.Run("it should retain the port ranges if specified - part of ports retained - split the rule", func(t *testing.T) { + var ( + rules = []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"bar", "foo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, // Different protocol, should not be touched. + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"baz"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(501)), + }, + }, + } + + sg = fx.Azure().SecurityGroup().WithRules(rules).Build() + helper = ExpectNewSecurityGroupHelper(t, &sg) + dstAddresses = []string{ + "10.0.0.1", + "10.0.0.2", + } + ) + + assert.NoError(t, helper.RemoveDestinationFromRules(network.SecurityRuleProtocolTCP, dstAddresses, []int32{443})) + + outputSG, updated, err := helper.SecurityGroup() + assert.NoError(t, err) + assert.True(t, updated) + testutil.ExpectEqualInJSON(t, []network.SecurityRule{ + { + Name: ptr.To("test-rule-0"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"bar", "foo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(500)), + }, + }, + { + Name: ptr.To("test-rule-1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolUDP, // Different protocol, should not be touched. + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"baz"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2", "192.168.0.1"}), + DestinationPortRanges: ptr.To([]string{"443", "80"}), + Priority: ptr.To(int32(501)), + }, + }, + { + Name: ptr.To("k8s-azure-lb_allow_IPv4_b5ae07e8a4177ea2d37162cdf2badf8b"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + Access: network.SecurityRuleAccessAllow, + Direction: network.SecurityRuleDirectionInbound, + SourceAddressPrefixes: ptr.To([]string{"bar", "foo"}), + SourcePortRange: ptr.To("*"), + DestinationAddressPrefixes: ptr.To([]string{"10.0.0.1", "10.0.0.2"}), + DestinationPortRanges: ptr.To([]string{"443"}), + Priority: ptr.To(int32(502)), + }, + }, + }, outputSG.SecurityRules) + }) + +} + func TestSecurityGroupHelper_RemoveDstAddressesFromRules(t *testing.T) { fx := fixture.NewFixture() + t.Run("it should not patch rules if no rules exist", func(t *testing.T) { var ( sg = fx.Azure().SecurityGroup().Build() diff --git a/pkg/provider/loadbalancer/securitygroup/securityrule.go b/pkg/provider/loadbalancer/securitygroup/securityrule.go new file mode 100644 index 0000000000..bd230f5178 --- /dev/null +++ b/pkg/provider/loadbalancer/securitygroup/securityrule.go @@ -0,0 +1,129 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package securitygroup + +import ( + "crypto/md5" //nolint:gosec + "fmt" + "sort" + "strconv" + "strings" + + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/fnutil" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer/iputil" +) + +// GenerateAllowSecurityRuleName returns the AllowInbound rule name based on the given rule properties. +func GenerateAllowSecurityRuleName( + protocol network.SecurityRuleProtocol, + ipFamily iputil.Family, + srcPrefixes []string, + dstPorts []int32, +) string { + var ruleID string + { + dstPortRanges := fnutil.Map(func(p int32) string { return strconv.FormatInt(int64(p), 10) }, dstPorts) + // Generate rule ID from protocol, source prefixes and destination port ranges. + sort.Strings(srcPrefixes) + sort.Strings(dstPortRanges) + + v := strings.Join([]string{ + string(protocol), + strings.Join(srcPrefixes, ","), + strings.Join(dstPortRanges, ","), + }, "_") + + h := md5.New() //nolint:gosec + h.Write([]byte(v)) + + ruleID = fmt.Sprintf("%x", h.Sum(nil)) + } + + return strings.Join([]string{SecurityRuleNamePrefix, "allow", string(ipFamily), ruleID}, SecurityRuleNameSep) +} + +// GenerateDenyAllSecurityRuleName returns the DenyInbound rule name based on the given rule properties. +func GenerateDenyAllSecurityRuleName(ipFamily iputil.Family) string { + return strings.Join([]string{SecurityRuleNamePrefix, "deny-all", string(ipFamily)}, SecurityRuleNameSep) +} + +// NormalizeSecurityRuleAddressPrefixes normalizes the given rule address prefixes. +func NormalizeSecurityRuleAddressPrefixes(vs []string) []string { + // Remove redundant addresses. + indexes := make(map[string]bool, len(vs)) + for _, v := range vs { + indexes[v] = true + } + rv := make([]string, 0, len(indexes)) + for k := range indexes { + rv = append(rv, k) + } + sort.Strings(rv) + return rv +} + +// NormalizeDestinationPortRanges normalizes the given destination port ranges. +func NormalizeDestinationPortRanges(dstPorts []int32) []string { + rv := fnutil.Map(func(p int32) string { return strconv.FormatInt(int64(p), 10) }, dstPorts) + sort.Strings(rv) + return rv +} + +func ListSourcePrefixes(r *network.SecurityRule) []string { + var rv []string + if r.SourceAddressPrefix != nil { + rv = append(rv, *r.SourceAddressPrefix) + } + if r.SourceAddressPrefixes != nil { + rv = append(rv, *r.SourceAddressPrefixes...) + } + return rv +} + +func ListDestinationPrefixes(r *network.SecurityRule) []string { + var rv []string + if r.DestinationAddressPrefix != nil { + rv = append(rv, *r.DestinationAddressPrefix) + } + if r.DestinationAddressPrefixes != nil { + rv = append(rv, *r.DestinationAddressPrefixes...) + } + return rv +} + +func ListDestinationPortRanges(r *network.SecurityRule) ([]int32, error) { + var values []string + if r.DestinationPortRange != nil { + values = append(values, *r.DestinationPortRange) + } + if r.DestinationPortRanges != nil { + values = append(values, *r.DestinationPortRanges...) + } + + rv := make([]int32, 0, len(values)) + for _, v := range values { + p, err := strconv.ParseInt(v, 10, 32) + if err != nil { + return nil, fmt.Errorf("parse port range %q: %w", v, err) + } + rv = append(rv, int32(p)) + } + + return rv, nil +} diff --git a/pkg/provider/loadbalancer/testutil/fixture/azure_publicip.go b/pkg/provider/loadbalancer/testutil/fixture/azure_publicip.go new file mode 100644 index 0000000000..f59c4bc9e3 --- /dev/null +++ b/pkg/provider/loadbalancer/testutil/fixture/azure_publicip.go @@ -0,0 +1,53 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fixture + +import ( + "fmt" + + "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "k8s.io/utils/ptr" +) + +func (f *AzureFixture) PublicIPAddress(name string) *AzurePublicIPAddressFixture { + const ( + SubscriptionID = "00000000-0000-0000-0000-000000000000" + ResourceGroup = "rg" + ) + + return &AzurePublicIPAddressFixture{ + pip: &network.PublicIPAddress{ + ID: ptr.To(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", SubscriptionID, ResourceGroup, name)), + Name: ptr.To(name), + Tags: make(map[string]*string), + PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{}, + }, + } +} + +type AzurePublicIPAddressFixture struct { + pip *network.PublicIPAddress +} + +func (f *AzurePublicIPAddressFixture) Build() network.PublicIPAddress { + return *f.pip +} + +func (f *AzurePublicIPAddressFixture) WithTag(key, value string) *AzurePublicIPAddressFixture { + f.pip.Tags[key] = ptr.To(value) + return f +} diff --git a/pkg/provider/loadbalancer/testutil/fixture/kubernetes.go b/pkg/provider/loadbalancer/testutil/fixture/kubernetes.go index da424a8d09..3b94a685bc 100644 --- a/pkg/provider/loadbalancer/testutil/fixture/kubernetes.go +++ b/pkg/provider/loadbalancer/testutil/fixture/kubernetes.go @@ -71,6 +71,16 @@ type KubernetesServiceFixture struct { svc v1.Service } +func (f *KubernetesServiceFixture) WithNamespace(ns string) *KubernetesServiceFixture { + f.svc.Namespace = ns + return f +} + +func (f *KubernetesServiceFixture) WithName(name string) *KubernetesServiceFixture { + f.svc.Name = name + return f +} + func (f *KubernetesServiceFixture) WithInternalEnabled() *KubernetesServiceFixture { f.svc.Annotations[consts.ServiceAnnotationLoadBalancerInternal] = "true" return f diff --git a/tests/e2e/network/network_security_group.go b/tests/e2e/network/network_security_group.go index 943ad02940..22fc16c1c6 100644 --- a/tests/e2e/network/network_security_group.go +++ b/tests/e2e/network/network_security_group.go @@ -707,6 +707,127 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( }) }) }) + + When("creating 2 LoadBalancer services with shared public IP", func() { + It("should add rules independently", func() { + + const ( + Deployment1Name = "app-01" + Deployment2Name = "app-02" + + Service1Name = "svc-01" + Service2Name = "svc-02" + ) + + var ( + app1Port int32 = 80 + app2Port int32 = 81 + replicas int32 = 2 + svc1IPv4s []netip.Addr + svc1IPv6s []netip.Addr + svc2IPs []netip.Addr + ) + + deployment1 := createDeploymentManifest(Deployment1Name, map[string]string{ + "app": Deployment1Name, + }, &app1Port, nil) + deployment1.Spec.Replicas = &replicas + _, err := k8sClient.AppsV1().Deployments(namespace.Name).Create(context.Background(), deployment1, metav1.CreateOptions{}) + Expect(err).NotTo(HaveOccurred()) + + deployment2 := createDeploymentManifest(Deployment2Name, map[string]string{ + "app": Deployment2Name, + }, &app2Port, nil) + deployment2.Spec.Replicas = &replicas + _, err = k8sClient.AppsV1().Deployments(namespace.Name).Create(context.Background(), deployment2, metav1.CreateOptions{}) + Expect(err).NotTo(HaveOccurred()) + + By("Creating service 1", func() { + var ( + labels = map[string]string{ + "app": Deployment1Name, + } + annotations = map[string]string{} + ports = []v1.ServicePort{{ + Port: app1Port, + TargetPort: intstr.FromInt32(app1Port), + }} + ) + rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service1Name, namespace.Name, labels, annotations, ports) + svc1IPv4s, svc1IPv6s = groupIPsByFamily(mustParseIPs(derefSliceOfStringPtr(rv))) + logger.Info("Created the first LoadBalancer service", "svc-name", Service1Name, "v4-IPs", svc1IPv4s, "v6-IPs", svc1IPv6s) + }) + + By("Creating service 2", func() { + var ( + labels = map[string]string{ + "app": Deployment2Name, + } + annotations = map[string]string{} + ports = []v1.ServicePort{{ + Port: app2Port, + TargetPort: intstr.FromInt32(app2Port), + }} + ) + var ip netip.Addr + if len(svc1IPv4s) > 0 { + ip = svc1IPv4s[0] + } + if len(svc1IPv6s) > 0 { + ip = svc1IPv6s[0] + } + + rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service2Name, namespace.Name, labels, annotations, ports, func(svc *v1.Service) error { + svc.Spec.LoadBalancerIP = ip.String() + return nil + }) + svc2IPs = mustParseIPs(derefSliceOfStringPtr(rv)) + logger.Info("Created the second LoadBalancer service", "svc-name", Service2Name, "IPs", svc2IPs) + Expect(svc2IPs).To(HaveLen(1)) + Expect(svc2IPs[0]).To(Equal(ip)) + }) + + var validator *SecurityGroupValidator + By("Getting the cluster security groups", func() { + rv, err := azureClient.GetClusterSecurityGroups() + Expect(err).NotTo(HaveOccurred()) + + validator = NewSecurityGroupValidator(rv) + }) + + By("Checking if the rule for allowing traffic for app 01", func() { + var ( + expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedDstPorts = []string{strconv.FormatInt(int64(app1Port), 10)} + ) + + By("Checking if the rule for allowing traffic from Internet exists") + + if len(svc1IPv4s) > 0 { + Expect( + validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc1IPv4s, expectedDstPorts), + ).To(BeTrue(), "Should not have a rule for allowing IPv4 traffic from Internet") + } + + if len(svc1IPv6s) > 0 { + Expect( + validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc1IPv6s, expectedDstPorts), + ).To(BeTrue(), "Should not have a rule for allowing IPv6 traffic from Internet") + } + }) + + By("Checking if the rule for allowing traffic for app 02", func() { + var ( + expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedDstPorts = []string{strconv.FormatInt(int64(app2Port), 10)} + ) + By("Checking if the rule for allowing traffic from Internet exists") + Expect( + validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc2IPs, expectedDstPorts), + ).To(BeTrue(), "Should not have a rule for allowing traffic from Internet") + }) + }) + }) }) type SecurityGroupValidator struct {