Skip to content

Commit

Permalink
Relu emu slow (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhuibao authored Jan 30, 2024
1 parent 99049f3 commit 70cf961
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 28 deletions.
6 changes: 3 additions & 3 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@
"mode": "debug",
"program": "${workspaceFolder}/samples/kmeans",
"args": [
"-timing",
// "-timing",
"-points=1024",
"-features=32",
"-clusters=5",
"-max-iter=5",
"-report-all",
"-max-iter=4",
// "-report-all",
],
},
{
Expand Down
23 changes: 21 additions & 2 deletions emu/computeunit.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type ComputeUnit struct {
GlobalMemStorage *mem.Storage

ToDispatcher sim.Port

finishedMapWGReqs []string
}

// ControlPort returns the port that can receive controlling messages from the
Expand Down Expand Up @@ -366,15 +368,32 @@ func (cu *ComputeUnit) resolveBarrier(wg *kernels.WorkGroup) {

func (cu *ComputeUnit) handleWGCompleteEvent(evt *WGCompleteEvent) error {
delete(cu.wfs, evt.Req.WorkGroup)
found := false
for _, r := range cu.finishedMapWGReqs {
if r == evt.Req.ID {
found = true
break
}
}
if !found {
cu.finishedMapWGReqs = append(cu.finishedMapWGReqs, evt.Req.ID)
}

if len(cu.wfs) != 0 {
return nil
}

req := protocol.WGCompletionMsgBuilder{}.
WithRspTo(evt.Req.ID).
WithSrc(cu.ToDispatcher).
WithDst(evt.Req.Src).
WithSendTime(evt.Time()).
WithRspTo(cu.finishedMapWGReqs).
Build()

err := cu.ToDispatcher.Send(req)
if err != nil {
if err == nil {
cu.finishedMapWGReqs = nil
} else {
newEvent := NewWGCompleteEvent(cu.Freq.NextTick(evt.Time()),
cu, evt.Req)
cu.Engine.Schedule(newEvent)
Expand Down
6 changes: 3 additions & 3 deletions protocol/cuprotocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (b MapWGReqBuilder) Build() *MapWGReq {
// execution
type WGCompletionMsg struct {
sim.MsgMeta
RspTo string
RspTo []string
}

// Meta returns the meta data associated with the MapWGReq.
Expand All @@ -287,7 +287,7 @@ func (r *WGCompletionMsg) Meta() *sim.MsgMeta {
type WGCompletionMsgBuilder struct {
sendTime sim.VTimeInSec
src, dst sim.Port
rspTo string
rspTo []string
}

// WithSendTime sets the send time.
Expand Down Expand Up @@ -316,7 +316,7 @@ func (b WGCompletionMsgBuilder) WithDst(

// WithRspTo sets rspTo
func (b WGCompletionMsgBuilder) WithRspTo(
rspTo string,
rspTo []string,
) WGCompletionMsgBuilder {
b.rspTo = rspTo
return b
Expand Down
44 changes: 28 additions & 16 deletions timing/cp/internal/dispatching/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dispatching

import (
"fmt"
"log"

"github.com/sarchlab/akita/v3/monitoring"
"github.com/sarchlab/akita/v3/sim"
Expand Down Expand Up @@ -120,28 +121,39 @@ func (d *DispatcherImpl) processMessagesFromCU(now sim.VTimeInSec) bool {

switch msg := msg.(type) {
case *protocol.WGCompletionMsg:
location, ok := d.inflightWGs[msg.RspTo]
if !ok {
return false
count := 0
for _, rspToID := range msg.RspTo {
_, ok := d.inflightWGs[rspToID]
if ok {
count += 1
}
}

d.alg.FreeResources(location)
delete(d.inflightWGs, msg.RspTo)
d.numCompletedWGs++
if d.numCompletedWGs == d.alg.NumWG() {
d.cycleLeft = d.constantKernelOverhead
if count == 0 {
return false
} else if count < len(msg.RspTo) {
log.Panic("In emulation all finished WGs from more than one dispatcher")
}

d.dispatchingPort.Retrieve(now)

originalReq := d.originalReqs[msg.RspTo]
delete(d.originalReqs, msg.RspTo)
tracing.TraceReqFinalize(originalReq, d)

if d.progressBar != nil {
d.progressBar.MoveInProgressToFinished(1)
for _, rspToID := range msg.RspTo {
location := d.inflightWGs[rspToID]
d.alg.FreeResources(location)
delete(d.inflightWGs, rspToID)
d.numCompletedWGs++
if d.numCompletedWGs == d.alg.NumWG() {
d.cycleLeft = d.constantKernelOverhead
}

originalReq := d.originalReqs[rspToID]
delete(d.originalReqs, rspToID)
tracing.TraceReqFinalize(originalReq, d)

if d.progressBar != nil {
d.progressBar.MoveInProgressToFinished(1)
}
}

d.dispatchingPort.Retrieve(now)
return true
}

Expand Down
6 changes: 3 additions & 3 deletions timing/cp/internal/dispatching/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ var _ = Describe("Dispatcher", func() {
dispatcher.inflightWGs[mapWGReq.ID] = location
dispatcher.originalReqs[mapWGReq.ID] = mapWGReq

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 48
Expand Down Expand Up @@ -197,7 +197,7 @@ var _ = Describe("Dispatcher", func() {
dispatcher.inflightWGs[mapWGReq.ID] = location
dispatcher.originalReqs[mapWGReq.ID] = mapWGReq

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 63
Expand Down Expand Up @@ -227,7 +227,7 @@ var _ = Describe("Dispatcher", func() {
mapWGReq := protocol.MapWGReqBuilder{}.Build()
// dispatcher.inflightWGs[mapWGReq.ID] = location

wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: mapWGReq.ID}
wgCompletionMsg := &protocol.WGCompletionMsg{RspTo: []string{mapWGReq.ID}}

dispatcher.numDispatchedWGs = 64
dispatcher.numCompletedWGs = 48
Expand Down
2 changes: 1 addition & 1 deletion timing/cu/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (s *SchedulerImpl) sendWGCompletionMessage(
WithSendTime(now).
WithSrc(s.cu.ToACE).
WithDst(dispatcher).
WithRspTo(mapReq.ID).
WithRspTo([]string{mapReq.ID}).
Build()

err := s.cu.ToACE.Send(msg)
Expand Down

0 comments on commit 70cf961

Please sign in to comment.