Skip to content

Commit

Permalink
Merge pull request projectdiscovery#592 from projectdiscovery/bugfix-…
Browse files Browse the repository at this point in the history
…progress-logic

Progress tracking logic
  • Loading branch information
ehsandeep authored Mar 4, 2021
2 parents 371f4be + b7c19e7 commit 0c5be83
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 75 deletions.
104 changes: 52 additions & 52 deletions v2/internal/progress/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (p *Progress) Init(hostCount int64, rulesCount int, requestCount int64) {
p.stats.AddCounter("total", uint64(requestCount))

if p.active {
if err := p.stats.Start(makePrintCallback(), p.tickDuration); err != nil {
if err := p.stats.Start(printCallback, p.tickDuration); err != nil {
gologger.Warning().Msgf("Couldn't start statistics: %s", err)
}
}
Expand All @@ -91,63 +91,61 @@ func (p *Progress) IncrementMatched() {
p.stats.IncrementCounter("matched", 1)
}

// DecrementRequests decrements the number of requests from total.
func (p *Progress) DecrementRequests(count int64) {
// IncrementErrorsBy increments the error counter by count.
func (p *Progress) IncrementErrorsBy(count int64) {
p.stats.IncrementCounter("errors", int(count))
}

// IncrementFailedRequestsBy increments the number of requests counter by count along with errors.
func (p *Progress) IncrementFailedRequestsBy(count int64) {
// mimic dropping by incrementing the completed requests
p.stats.IncrementCounter("requests", int(count))
p.stats.IncrementCounter("errors", int(count))
}

const bufferSize = 128

func makePrintCallback() func(stats clistats.StatisticsClient) {
func printCallback(stats clistats.StatisticsClient) {
builder := &strings.Builder{}
builder.Grow(bufferSize)

return func(stats clistats.StatisticsClient) {
builder.WriteRune('[')
startedAt, _ := stats.GetStatic("startedAt")
duration := time.Since(startedAt.(time.Time))
builder.WriteString(fmtDuration(duration))
builder.WriteRune(']')

templates, _ := stats.GetStatic("templates")
builder.WriteString(" | Templates: ")
builder.WriteString(clistats.String(templates))
hosts, _ := stats.GetStatic("hosts")
builder.WriteString(" | Hosts: ")
builder.WriteString(clistats.String(hosts))

requests, _ := stats.GetCounter("requests")
total, _ := stats.GetCounter("total")

builder.WriteString(" | RPS: ")
builder.WriteString(clistats.String(uint64(float64(requests) / duration.Seconds())))

matched, _ := stats.GetCounter("matched")

builder.WriteString(" | Matched: ")
builder.WriteString(clistats.String(matched))

errors, _ := stats.GetCounter("errors")
builder.WriteString(" | Errors: ")
builder.WriteString(clistats.String(errors))

builder.WriteString(" | Requests: ")
builder.WriteString(clistats.String(requests))
builder.WriteRune('/')
builder.WriteString(clistats.String(total))
builder.WriteRune(' ')
builder.WriteRune('(')
//nolint:gomnd // this is not a magic number
builder.WriteString(clistats.String(uint64(float64(requests) / float64(total) * 100.0)))
builder.WriteRune('%')
builder.WriteRune(')')
builder.WriteRune('\n')

gologger.Print().Msgf("%s", builder.String())
builder.Reset()
}
builder.WriteRune('[')
startedAt, _ := stats.GetStatic("startedAt")
duration := time.Since(startedAt.(time.Time))
builder.WriteString(fmtDuration(duration))
builder.WriteRune(']')

templates, _ := stats.GetStatic("templates")
builder.WriteString(" | Templates: ")
builder.WriteString(clistats.String(templates))
hosts, _ := stats.GetStatic("hosts")
builder.WriteString(" | Hosts: ")
builder.WriteString(clistats.String(hosts))

requests, _ := stats.GetCounter("requests")
total, _ := stats.GetCounter("total")

builder.WriteString(" | RPS: ")
builder.WriteString(clistats.String(uint64(float64(requests) / duration.Seconds())))

matched, _ := stats.GetCounter("matched")

builder.WriteString(" | Matched: ")
builder.WriteString(clistats.String(matched))

errors, _ := stats.GetCounter("errors")
builder.WriteString(" | Errors: ")
builder.WriteString(clistats.String(errors))

builder.WriteString(" | Requests: ")
builder.WriteString(clistats.String(requests))
builder.WriteRune('/')
builder.WriteString(clistats.String(total))
builder.WriteRune(' ')
builder.WriteRune('(')
//nolint:gomnd // this is not a magic number
builder.WriteString(clistats.String(uint64(float64(requests) / float64(total) * 100.0)))
builder.WriteRune('%')
builder.WriteRune(')')
builder.WriteRune('\n')

gologger.Print().Msgf("%s", builder.String())
}

// getMetrics returns a map of important metrics for client
Expand Down Expand Up @@ -194,6 +192,8 @@ func fmtDuration(d time.Duration) string {
// Stop stops the progress bar execution
func (p *Progress) Stop() {
if p.active {
// Print one final summary
printCallback(p.stats)
if err := p.stats.Stop(); err != nil {
gologger.Warning().Msgf("Couldn't stop statistics: %s", err)
}
Expand Down
4 changes: 2 additions & 2 deletions v2/pkg/protocols/dns/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
compiledRequest, err := r.Make(domain)
if err != nil {
r.options.Output.Request(r.options.TemplateID, domain, "dns", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not build request")
}

Expand All @@ -38,7 +38,7 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
resp, err := r.dnsClient.Do(compiledRequest)
if err != nil {
r.options.Output.Request(r.options.TemplateID, domain, "dns", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not send dns request")
}
r.options.Progress.IncrementRequests()
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/protocols/file/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
wg.Wait()
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "file", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not send file request")
}
r.options.Progress.IncrementRequests()
Expand Down
6 changes: 3 additions & 3 deletions v2/pkg/protocols/headless/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
instance, err := r.options.Browser.NewInstance()
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "headless", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could get html element")
}
defer instance.Close()

parsed, err := url.Parse(input)
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "headless", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could get html element")
}
out, page, err := instance.Run(parsed, r.Steps, time.Duration(r.options.Options.PageTimeout)*time.Second)
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "headless", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could get html element")
}
defer page.Close()
Expand Down
13 changes: 5 additions & 8 deletions v2/pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func (r *Request) executeRaceRequest(reqURL string, previous output.InternalEven
mutex.Lock()
if err != nil {
requestErr = multierr.Append(requestErr, err)
r.options.Progress.DecrementRequests(1)
}
mutex.Unlock()
}(requests[i])
Expand All @@ -95,7 +94,7 @@ func (r *Request) executeParallelHTTP(reqURL string, dynamicValues, previous out
break
}
if err != nil {
r.options.Progress.DecrementRequests(int64(generator.Total()))
r.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return err
}
swg.Add()
Expand All @@ -107,7 +106,6 @@ func (r *Request) executeParallelHTTP(reqURL string, dynamicValues, previous out
mutex.Lock()
if err != nil {
requestErr = multierr.Append(requestErr, err)
r.options.Progress.DecrementRequests(1)
}
mutex.Unlock()
}(request)
Expand Down Expand Up @@ -154,7 +152,7 @@ func (r *Request) executeTurboHTTP(reqURL string, dynamicValues, previous output
break
}
if err != nil {
r.options.Progress.DecrementRequests(int64(generator.Total()))
r.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return err
}
request.pipelinedClient = pipeclient
Expand All @@ -167,7 +165,6 @@ func (r *Request) executeTurboHTTP(reqURL string, dynamicValues, previous output
mutex.Lock()
if err != nil {
requestErr = multierr.Append(requestErr, err)
r.options.Progress.DecrementRequests(1)
}
mutex.Unlock()
}(request)
Expand Down Expand Up @@ -203,7 +200,7 @@ func (r *Request) ExecuteWithResults(reqURL string, dynamicValues, previous outp
break
}
if err != nil {
r.options.Progress.DecrementRequests(int64(generator.Total()))
r.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return err
}

Expand All @@ -223,7 +220,7 @@ func (r *Request) ExecuteWithResults(reqURL string, dynamicValues, previous outp
r.options.Progress.IncrementRequests()

if request.original.options.Options.StopAtFirstMatch && gotOutput {
r.options.Progress.DecrementRequests(int64(generator.Total()))
r.options.Progress.IncrementErrorsBy(int64(generator.Total()))
break
}
}
Expand Down Expand Up @@ -300,7 +297,7 @@ func (r *Request) executeRequest(reqURL string, request *generatedRequest, previ
resp.Body.Close()
}
r.options.Output.Request(r.options.TemplateID, reqURL, "http", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementErrorsBy(1)
return err
}
defer func() {
Expand Down
14 changes: 7 additions & 7 deletions v2/pkg/protocols/network/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
address, err := getAddress(input)
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not get address from url")
}

Expand All @@ -50,7 +50,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
if !strings.Contains(actualAddress, ":") {
err := errors.New("no port provided in network protocol request")
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return err
}

Expand All @@ -71,7 +71,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
}
if err != nil {
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not connect to server request")
}
defer conn.Close()
Expand All @@ -92,7 +92,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
}
if err != nil {
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not write request to server")
}
reqBuilder.Grow(len(input.Data))
Expand All @@ -101,7 +101,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
_, err = conn.Write(data)
if err != nil {
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not write request to server")
}

Expand All @@ -117,7 +117,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
}
if err != nil {
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not write request to server")
}

Expand All @@ -137,7 +137,7 @@ func (r *Request) executeAddress(actualAddress, address, input string, shouldUse
n, err := conn.Read(final)
if err != nil && err != io.EOF {
r.options.Output.Request(r.options.TemplateID, address, "network", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not read from server")
}
responseBuilder.Write(final[:n])
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/protocols/offlinehttp/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (r *Request) ExecuteWithResults(input string, metadata, previous output.Int
wg.Wait()
if err != nil {
r.options.Output.Request(r.options.TemplateID, input, "file", err)
r.options.Progress.DecrementRequests(1)
r.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, "could not send file request")
}
r.options.Progress.IncrementRequests()
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type Options struct {
Retries int
// Rate-Limit is the maximum number of requests per specified target
RateLimit int
//`PageTimeout is the maximum time to wait for a page in seconds
// PageTimeout is the maximum time to wait for a page in seconds
PageTimeout int
// OfflineHTTP is a flag that specific offline processing of http response
// using same matchers/extractors from http protocol without the need
Expand Down

0 comments on commit 0c5be83

Please sign in to comment.