diff --git a/proxy/merging.go b/proxy/merging.go index 8949495a2..7f415173b 100644 --- a/proxy/merging.go +++ b/proxy/merging.go @@ -81,7 +81,37 @@ func parallelMerge(timeout time.Duration, rc ResponseCombiner, next ...Proxy) Pr var reMergeKey = regexp.MustCompile(`\{\{\.Resp(\d+)_([\d\w-_\.]+)\}\}`) +func isBlocking(i int, deps [][]int) bool { + for _, dep := range deps { + for _, j := range dep { + if i == j { + return true + } + } + } + return false +} + func sequentialMerge(patterns []string, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { + deps := make([][]int, len(patterns)) + matches := make([][][]string, len(patterns)) + for i, pattern := range patterns { + matches[i] = reMergeKey.FindAllStringSubmatch(pattern, -1) + deps[i] = make([]int, 0, len(matches)) + for _, match := range matches[i] { + if rNum, err := strconv.Atoi(match[1]); err == nil { + var found bool + for _, j := range deps[i] { + if j == rNum { + found = true + } + } + if !found { + deps[i] = append(deps[i], rNum) + } + } + } + } return func(ctx context.Context, request *Request) (*Response, error) { localCtx, cancel := context.WithTimeout(ctx, timeout) @@ -89,72 +119,109 @@ func sequentialMerge(patterns []string, timeout time.Duration, rc ResponseCombin out := make(chan *Response, 1) errCh := make(chan error, 1) + type partResult struct { i int; err error } + ch := make(chan partResult, len(next)) + done := make([]chan struct {}, len(next)) + for i := range next { + done[i] = make(chan struct {}, 1) + } + acc := newIncrementalMergeAccumulator(len(next), rc) - TxLoop: - for i, n := range next { - if i > 0 { - for _, match := range reMergeKey.FindAllStringSubmatch(patterns[i], -1) { - if len(match) > 1 { - rNum, err := strconv.Atoi(match[1]) - if err != nil || rNum >= i || parts[rNum] == nil { - continue - } - key := "Resp" + match[1] + "_" + match[2] - - var v interface{} - var ok bool - - data := parts[rNum].Data - keys := strings.Split(match[2], ".") - if len(keys) > 1 { - for _, k := range keys[:len(keys)-1] { - v, ok = data[k] - if !ok { - break - } - switch clean := v.(type) { - case map[string]interface{}: - data = clean - default: - break + + for i := range next { + go func(i int) { + n := next[i] + for _, j := range deps[i] { + select{ + case <-done[j]: + case <-localCtx.Done(): + return + } + } + var req *Request = request + if i > 0 { + req = CloneRequest(request) + for _, match := range matches[i] { + if len(match) > 1 { + rNum, err := strconv.Atoi(match[1]) + if err != nil || rNum >= i || parts[rNum] == nil { + continue + } + key := "Resp" + match[1] + "_" + match[2] + + var v interface{} + var ok bool + + data := parts[rNum].Data + keys := strings.Split(match[2], ".") + if len(keys) > 1 { + for _, k := range keys[:len(keys)-1] { + v, ok = data[k] + if !ok { + break + } + switch clean := v.(type) { + case map[string]interface{}: + data = clean + default: + break + } } } - } - v, ok = data[keys[len(keys)-1]] - if !ok { - continue - } - switch clean := v.(type) { - case string: - request.Params[key] = clean - case int: - request.Params[key] = strconv.Itoa(clean) - case float64: - request.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32) - case bool: - request.Params[key] = strconv.FormatBool(clean) - default: - request.Params[key] = fmt.Sprintf("%v", v) + v, ok = data[keys[len(keys)-1]] + if !ok { + continue + } + switch clean := v.(type) { + case string: + req.Params[key] = clean + case int: + req.Params[key] = strconv.Itoa(clean) + case float64: + req.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32) + case bool: + req.Params[key] = strconv.FormatBool(clean) + default: + req.Params[key] = fmt.Sprintf("%v", v) + } } } } - } - requestPart(localCtx, n, request, out, errCh) - select { - case err := <-errCh: - if i == 0 { - cancel() - return nil, err + requestPart(localCtx, n, req, out, errCh) + select { + case err := <-errCh: + ch <- partResult{i, err} + case response := <-out: + parts[i] = response + close(done[i]) + if !response.IsComplete { + cancel() + return + } + ch <- partResult{i, nil} } - acc.Merge(nil, err) - break TxLoop - case response := <-out: - acc.Merge(response, nil) - if !response.IsComplete { - break TxLoop + }(i) + } + + for _ = range next { + select { + case res := <-ch: + if i, err := res.i, res.err; err != nil { + acc.Merge(nil, err) + if isBlocking(i, deps) { + cancel() + break + } } - parts[i] = response + case <-localCtx.Done(): + break + } + } + + for _, part := range parts { + if part != nil { + acc.Merge(part, nil) } } @@ -201,7 +268,14 @@ func (i *incrementalMergeAccumulator) Merge(res *Response, err error) { func (i *incrementalMergeAccumulator) Result() (*Response, error) { if i.data == nil { - return &Response{Data: make(map[string]interface{}, 0), IsComplete: false}, newMergeError(i.errs) + err := newMergeError(i.errs) + + // none succeeded + if len(i.errs) == 1 { + return nil, err + } + + return &Response{Data: make(map[string]interface{}, 0), IsComplete: false}, err } if i.pending != 0 || len(i.errs) != 0 { @@ -235,6 +309,8 @@ func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- * func newMergeError(errs []error) error { if len(errs) == 0 { return nil + } else if len(errs) == 1 { + return errs[0] } return mergeError{errs} }