Skip to content

Commit

Permalink
Merge pull request #51 from boris1993/ipv6
Browse files Browse the repository at this point in the history
[bug fix] correctly comparing 2 IP addresses
  • Loading branch information
boris1993 authored Nov 28, 2020
2 parents f4bce40 + ea1c4e1 commit 0f647ec
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 92 deletions.
33 changes: 16 additions & 17 deletions cmd/dnsupdater/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ package main
import (
"errors"
"flag"
"github.com/boris1993/dnsupdater/internal/configs"
"github.com/boris1993/dnsupdater/internal/constants"
"github.com/boris1993/dnsupdater/internal/common"
"github.com/boris1993/dnsupdater/internal/helper/aliyun"
"github.com/boris1993/dnsupdater/internal/helper/cloudflare"
log "github.com/sirupsen/logrus"
Expand All @@ -16,7 +15,7 @@ import (
func main() {
var err error

config, err := configs.Get()
config, err := common.GetConfig()
if err != nil {
log.Fatalln(err)
}
Expand All @@ -31,7 +30,7 @@ func main() {
// the currentIPv6Address will be an empty string
var currentIPv6Address = ""
if config.System.IPv6AddrAPI == "" {
log.Info(constants.MsgIPv6Disabled)
log.Info(common.MsgIPv6Disabled)
} else {
currentIPv6Address, err = getCurrentIPv6Address(*config)
if err != nil {
Expand Down Expand Up @@ -60,27 +59,27 @@ func main() {
}

func init() {
flag.StringVar(&configs.Path, "config", "", "Path to the config file.")
flag.BoolVar(&configs.Debug, "debug", false, "Enable debug logging.")
flag.StringVar(&common.ConfigFilePath, "config", "", "Path to the config file.")
flag.BoolVar(&common.Debug, "debug", false, "Enable debug logging.")

flag.Parse()

log.SetFormatter(&log.TextFormatter{DisableLevelTruncation: true})

if configs.Debug == true {
if common.Debug == true {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
}
}

// getCurrentIPv4Address returns the external IP address for your network.
func getCurrentIPv4Address(config configs.Config) (string, error) {
func getCurrentIPv4Address(config common.Config) (string, error) {
if config.System.IPAddrAPI == "" {
return "", errors.New(constants.ErrIPAddressFetchingAPIEmpty)
return "", errors.New(common.ErrIPAddressFetchingAPIEmpty)
}

log.Println(constants.MsgCheckingCurrentIPv4Addr)
log.Println(common.MsgCheckingCurrentIPv4Addr)

//region fetch your IPv4 address
resp, err := http.Get(config.System.IPAddrAPI)
Expand All @@ -93,7 +92,7 @@ func getCurrentIPv4Address(config configs.Config) (string, error) {
err := resp.Body.Close()

if err != nil {
log.Errorln(constants.ErrCloseHTTPConnectionFail, err)
log.Errorln(common.ErrCloseHTTPConnectionFail, err)
}
}()

Expand All @@ -106,19 +105,19 @@ func getCurrentIPv4Address(config configs.Config) (string, error) {
ipAddress := string(body)
//endregion

log.Println(constants.MsgHeaderCurrentIPv4Addr, ipAddress)
log.Println(common.MsgHeaderCurrentIPv4Addr, ipAddress)

return ipAddress, nil
}

// getCurrentIPv6Address returns the external IPv6 address for your network.
// Typically this should be your "temporary" IPv6 address.
func getCurrentIPv6Address(config configs.Config) (string, error) {
func getCurrentIPv6Address(config common.Config) (string, error) {
if config.System.IPv6AddrAPI == "" {
return "", errors.New(constants.ErrIPAddressFetchingAPIEmpty)
return "", errors.New(common.ErrIPAddressFetchingAPIEmpty)
}

log.Println(constants.MsgCheckingCurrentIPv6Addr)
log.Println(common.MsgCheckingCurrentIPv6Addr)

resp, err := http.Get(config.System.IPv6AddrAPI)
if err != nil {
Expand All @@ -130,7 +129,7 @@ func getCurrentIPv6Address(config configs.Config) (string, error) {
err := resp.Body.Close()

if err != nil {
log.Errorln(constants.ErrCloseHTTPConnectionFail, err)
log.Errorln(common.ErrCloseHTTPConnectionFail, err)
}
}()

Expand All @@ -142,7 +141,7 @@ func getCurrentIPv6Address(config configs.Config) (string, error) {
// Body only contains the IP address
ipv6Address := string(body)

log.Println(constants.MsgHeaderCurrentIPv6Addr, ipv6Address)
log.Println(common.MsgHeaderCurrentIPv6Addr, ipv6Address)

return ipv6Address, nil
}
15 changes: 7 additions & 8 deletions internal/configs/config.go → internal/common/config.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Package conf provides all models needed by this programme.
package configs
package common

import (
"github.com/boris1993/dnsupdater/internal/constants"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"io/ioutil"
Expand All @@ -15,7 +14,7 @@ var once = new(sync.Once)

var Debug bool

var Path string
var ConfigFilePath string
var conf Config
var errorInInitConfig error

Expand Down Expand Up @@ -52,7 +51,7 @@ type AliDNS struct {
DomainType string `yaml:"DomainType"`
}

func Get() (*Config, error) {
func GetConfig() (*Config, error) {
once.Do(func() {
err := initConfig()

Expand All @@ -71,18 +70,18 @@ func Get() (*Config, error) {

// initConfig reads the Config.yaml and saves the properties in a variable.
func initConfig() error {
if Path == "" {
if ConfigFilePath == "" {
absPath, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
return err
}

Path = filepath.Join(absPath, "config.yaml")
ConfigFilePath = filepath.Join(absPath, "config.yaml")
}

log.Println(constants.MsgHeaderLoadingConfig, Path)
log.Println(MsgHeaderLoadingConfig, ConfigFilePath)

bytes, err := ioutil.ReadFile(Path)
bytes, err := ioutil.ReadFile(ConfigFilePath)

if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package configs
package common

import (
"os"
Expand All @@ -16,13 +16,13 @@ func TestGet(t *testing.T) {

func testGetSuccess(t *testing.T) {
Debug = true
Path = testResourcePath + "/test_config.yaml"
if _, err := os.Stat(Path); os.IsNotExist(err) {
ConfigFilePath = testResourcePath + "/test_config.yaml"
if _, err := os.Stat(ConfigFilePath); os.IsNotExist(err) {
t.Errorf("test_config.yaml doesn't exist")
return
}

config, err := Get()
config, err := GetConfig()
if err != nil {
t.Error(err)
return
Expand All @@ -49,9 +49,9 @@ func testGetSuccess(t *testing.T) {

func testGetFail(t *testing.T) {
Debug = true
Path = testResourcePath + "/non_existent_config.yaml"
ConfigFilePath = testResourcePath + "/non_existent_config.yaml"

_, err := Get()
_, err := GetConfig()
if err == nil {
t.Error("TestGetFail should fail")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Package constants contains all constants needed in this programme
package constants
package common

const MsgHeaderDNSRecordUpdateSuccessful = "Successfully updated the DNS record"
const MsgHeaderCurrentIPv4Addr = "Current IPv4 address is:"
Expand Down
15 changes: 15 additions & 0 deletions internal/common/ip_address_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package common

import "net"

// CompareAddresses compares 2 given IP address string and see if they are the same IP address
func CompareAddresses(address1 string, address2 string) bool {
ipAddr1 := net.ParseIP(address1)
ipAddr2 := net.ParseIP(address2)

if ipAddr1 == nil || ipAddr2 == nil {
return false
}

return ipAddr1.Equal(ipAddr2)
}
36 changes: 36 additions & 0 deletions internal/common/ip_address_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package common

import "testing"

func TestCompareAddresses(t *testing.T) {
var sameAddress bool
var failTestMessage = "%s and %s should be the same address"

sameAddress = CompareAddresses("192.168.1.1", "192.168.001.001")
if !sameAddress {
t.Errorf(failTestMessage, "192.168.1.1", "192.168.001.001")
}

sameAddress = CompareAddresses(
"2001:0db8:0000:0000:0000:ff00:0042:8329", "2001:db8:0:0:0:ff00:42:8329")
if !sameAddress {
t.Errorf(failTestMessage, "2001:0db8:0000:0000:0000:ff00:0042:8329", "2001:db8:0:0:0:ff00:42:8329")
}

sameAddress = CompareAddresses(
"2001:0db8:0000:0000:0000:ff00:0042:8329", "2001:db8::ff00:42:8329")
if !sameAddress {
t.Errorf(failTestMessage, "2001:0db8:0000:0000:0000:ff00:0042:8329", "2001:db8::ff00:42:8329")
}

sameAddress = CompareAddresses(
"2001:db8:0:0:0:ff00:42:8329", "2001:db8::ff00:42:8329")
if !sameAddress {
t.Errorf(failTestMessage, "2001:db8:0:0:0:ff00:42:8329", "2001:db8::ff00:42:8329")
}

sameAddress = CompareAddresses("not valid address", "not valid address")
if sameAddress {
t.Error("Should return false when comparing invalid IP addresses")
}
}
Loading

0 comments on commit 0f647ec

Please sign in to comment.