Skip to content

Commit

Permalink
fix regression that prevented setting config slices with env variables (
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Oct 4, 2021
1 parent 0d15e27 commit b70a4bf
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 14 deletions.
37 changes: 33 additions & 4 deletions internal/conf/authmethod.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,39 @@ package conf
import (
"encoding/json"
"fmt"
"strings"

"github.com/aler9/gortsplib/pkg/headers"
)

func unmarshalStringSlice(b []byte) ([]string, error) {
var in interface{}
if err := json.Unmarshal(b, &in); err != nil {
return nil, err
}

var slice []string

switch it := in.(type) {
case string: // from environment variables
slice = strings.Split(it, ",")

case []interface{}: // from yaml
for _, e := range it {
et, ok := e.(string)
if !ok {
return nil, fmt.Errorf("cannot unmarshal from %T", e)
}
slice = append(slice, et)
}

default:
return nil, fmt.Errorf("cannot unmarshal from %T", in)
}

return slice, nil
}

// AuthMethods is the authMethods parameter.
type AuthMethods []headers.AuthMethod

Expand All @@ -29,12 +58,12 @@ func (d AuthMethods) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals a AuthMethods from JSON.
func (d *AuthMethods) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}

for _, v := range in {
for _, v := range slice {
switch v {
case "basic":
*d = append(*d, headers.AuthBasic)
Expand All @@ -43,7 +72,7 @@ func (d *AuthMethods) UnmarshalJSON(b []byte) error {
*d = append(*d, headers.AuthDigest)

default:
return fmt.Errorf("invalid authentication method: %s", in)
return fmt.Errorf("invalid authentication method: %s", v)
}
}

Expand Down
5 changes: 5 additions & 0 deletions internal/conf/conf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func TestConfFromFileAndEnv(t *testing.T) {
os.Setenv("RTSP_PATHS_CAM1_SOURCE", "rtsp://testing")
defer os.Unsetenv("RTSP_PATHS_CAM1_SOURCE")

os.Setenv("RTSP_PROTOCOLS", "tcp")
defer os.Unsetenv("RTSP_PROTOCOLS")

tmpf, err := writeTempFile([]byte("{}"))
require.NoError(t, err)
defer os.Remove(tmpf)
Expand All @@ -94,6 +97,8 @@ func TestConfFromFileAndEnv(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, hasFile)

require.Equal(t, Protocols{ProtocolTCP: {}}, conf.Protocols)

pa, ok := conf.Paths["cam1"]
require.Equal(t, true, ok)
require.Equal(t, &PathConf{
Expand Down
8 changes: 4 additions & 4 deletions internal/conf/ipsornets.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ func (d IPsOrNets) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals a IPsOrNets from JSON.
func (d *IPsOrNets) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}

if len(in) == 0 {
if len(slice) == 0 {
return nil
}

for _, t := range in {
for _, t := range slice {
if _, ipnet, err := net.ParseCIDR(t); err == nil {
*d = append(*d, ipnet)
} else if ip := net.ParseIP(t); ip != nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/conf/logdestination.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ func (d LogDestinations) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals a LogDestinations from JSON.
func (d *LogDestinations) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}

*d = make(LogDestinations)

for _, proto := range in {
for _, proto := range slice {
switch proto {
case "stdout":
(*d)[logger.DestinationStdout] = struct{}{}
Expand Down
6 changes: 3 additions & 3 deletions internal/conf/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ func (d Protocols) MarshalJSON() ([]byte, error) {

// UnmarshalJSON unmarshals a Protocols from JSON.
func (d *Protocols) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
slice, err := unmarshalStringSlice(b)
if err != nil {
return err
}

*d = make(Protocols)

for _, proto := range in {
for _, proto := range slice {
switch proto {
case "udp":
(*d)[ProtocolUDP] = struct{}{}
Expand Down
2 changes: 2 additions & 0 deletions internal/core/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func TestAPIConfigSet(t *testing.T) {
err := httpRequest(http.MethodPost, "http://localhost:9997/v1/config/set", map[string]interface{}{
"rtmpDisable": true,
"readTimeout": "7s",
"protocols": []string{"tcp"},
}, nil)
require.NoError(t, err)

Expand All @@ -82,6 +83,7 @@ func TestAPIConfigSet(t *testing.T) {
require.NoError(t, err)
require.Equal(t, true, out["rtmpDisable"])
require.Equal(t, "7s", out["readTimeout"])
require.Equal(t, []interface{}{"tcp"}, out["protocols"])
}

func TestAPIConfigPathsAdd(t *testing.T) {
Expand Down

0 comments on commit b70a4bf

Please sign in to comment.