Skip to content

Commit

Permalink
fix compile errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MengyangHe1 committed Aug 20, 2024
1 parent a5c06a0 commit 3a489a0
Show file tree
Hide file tree
Showing 32 changed files with 772 additions and 897 deletions.
151 changes: 65 additions & 86 deletions driver/driver.go

Large diffs are not rendered by default.

60 changes: 25 additions & 35 deletions driver/memorycopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,25 @@ type defaultMemoryCopyMiddleware struct {
}

func (m *defaultMemoryCopyMiddleware) ProcessCommand(
now sim.VTimeInSec,
cmd Command,
queue *CommandQueue,
) (processed bool) {
switch cmd := cmd.(type) {
case *MemCopyH2DCommand:
return m.processMemCopyH2DCommand(now, cmd, queue)
return m.processMemCopyH2DCommand(cmd, queue)
case *MemCopyD2HCommand:
return m.processMemCopyD2HCommand(now, cmd, queue)
return m.processMemCopyD2HCommand(cmd, queue)
}

return false
}

func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand(
now sim.VTimeInSec,
cmd *MemCopyH2DCommand,
queue *CommandQueue,
) bool {
if m.needFlushing(queue.Context, cmd.Dst, uint64(binary.Size(cmd.Src))) {
m.sendFlushRequest(now, cmd)
m.sendFlushRequest(cmd)
}

buffer := bytes.NewBuffer(nil)
Expand All @@ -68,7 +66,7 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand(
}

gpuID := m.driver.memAllocator.GetDeviceIDByPAddr(pAddr)
req := protocol.NewMemCopyH2DReq(now,
req := protocol.NewMemCopyH2DReq(
m.driver.gpuPort, m.driver.GPUs[gpuID-1],
rawBytes[offset:offset+sizeToCopy],
pAddr)
Expand All @@ -80,7 +78,7 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand(
addr += sizeToCopy
offset += sizeToCopy

m.driver.logTaskToGPUInitiate(now, cmd, req)
m.driver.logTaskToGPUInitiate(cmd, req)
}

m.cyclesLeft = m.cyclesPerH2D
Expand All @@ -91,12 +89,11 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DCommand(
}

func (m *defaultMemoryCopyMiddleware) processMemCopyD2HCommand(
now sim.VTimeInSec,
cmd *MemCopyD2HCommand,
queue *CommandQueue,
) bool {
if m.needFlushing(queue.Context, cmd.Src, uint64(binary.Size(cmd.Dst))) {
m.sendFlushRequest(now, cmd)
m.sendFlushRequest(cmd)
queue.Context.removeFreedBuffers()
}

Expand All @@ -119,7 +116,7 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyD2HCommand(
}

gpuID := m.driver.memAllocator.GetDeviceIDByPAddr(pAddr)
req := protocol.NewMemCopyD2HReq(now,
req := protocol.NewMemCopyD2HReq(
m.driver.gpuPort, m.driver.GPUs[gpuID-1],
pAddr, cmd.RawData[offset:offset+sizeToCopy])
cmd.Reqs = append(cmd.Reqs, req)
Expand All @@ -130,7 +127,7 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyD2HCommand(
addr += sizeToCopy
offset += sizeToCopy

m.driver.logTaskToGPUInitiate(now, cmd, req)
m.driver.logTaskToGPUInitiate(cmd, req)
}

m.cyclesLeft = m.cyclesPerD2H
Expand Down Expand Up @@ -174,21 +171,18 @@ func memRangeOverlap(
}

func (m *defaultMemoryCopyMiddleware) sendFlushRequest(
now sim.VTimeInSec,
cmd Command,
) {
for _, gpu := range m.driver.GPUs {
req := protocol.NewFlushReq(now, m.driver.gpuPort, gpu)
req := protocol.NewFlushReq(m.driver.gpuPort, gpu)
m.driver.requestsToSend = append(m.driver.requestsToSend, req)
cmd.AddReq(req)

m.driver.logTaskToGPUInitiate(now, cmd, req)
m.driver.logTaskToGPUInitiate(cmd, req)
}
}

func (m *defaultMemoryCopyMiddleware) Tick(
now sim.VTimeInSec,
) (madeProgress bool) {
func (m *defaultMemoryCopyMiddleware) Tick() (madeProgress bool) {
madeProgress = false

if m.cyclesLeft > 0 {
Expand All @@ -201,44 +195,42 @@ func (m *defaultMemoryCopyMiddleware) Tick(
madeProgress = true
}

req := m.driver.gpuPort.Peek()
req := m.driver.gpuPort.PeekIncoming()
if req == nil {
return madeProgress
}

switch req := req.(type) {
case *sim.GeneralRsp:
madeProgress = m.processGeneralRsp(now, req)
madeProgress = m.processGeneralRsp(req)
}

return madeProgress
}

func (m *defaultMemoryCopyMiddleware) processGeneralRsp(
now sim.VTimeInSec,
rsp *sim.GeneralRsp,
) bool {
originalReq := rsp.OriginalReq

switch originalReq := originalReq.(type) {
case *protocol.FlushReq:
return m.processFlushReturn(now, originalReq)
return m.processFlushReturn(originalReq)
case *protocol.MemCopyH2DReq:
return m.processMemCopyH2DReturn(now, originalReq)
return m.processMemCopyH2DReturn(originalReq)
case *protocol.MemCopyD2HReq:
return m.processMemCopyD2HReturn(now, originalReq)
return m.processMemCopyD2HReturn(originalReq)
}

return false
}

func (m *defaultMemoryCopyMiddleware) processMemCopyH2DReturn(
now sim.VTimeInSec,
req *protocol.MemCopyH2DReq,
) bool {
m.driver.gpuPort.Retrieve(now)
m.driver.gpuPort.RetrieveIncoming()

m.driver.logTaskToGPUClear(now, req)
m.driver.logTaskToGPUClear(req)

cmd, cmdQueue := m.driver.findCommandByReq(req)

Expand All @@ -255,19 +247,18 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyH2DReturn(
cmdQueue.IsRunning = false
cmdQueue.Dequeue()

m.driver.logCmdComplete(cmd, now)
m.driver.logCmdComplete(cmd)
}

return true
}

func (m *defaultMemoryCopyMiddleware) processMemCopyD2HReturn(
now sim.VTimeInSec,
req *protocol.MemCopyD2HReq,
) bool {
m.driver.gpuPort.Retrieve(now)
m.driver.gpuPort.RetrieveIncoming()

m.driver.logTaskToGPUClear(now, req)
m.driver.logTaskToGPUClear(req)

cmd, cmdQueue := m.driver.findCommandByReq(req)

Expand All @@ -284,25 +275,24 @@ func (m *defaultMemoryCopyMiddleware) processMemCopyD2HReturn(

cmdQueue.Dequeue()

m.driver.logCmdComplete(copyCmd, now)
m.driver.logCmdComplete(copyCmd)
}

return true
}

func (m *defaultMemoryCopyMiddleware) processFlushReturn(
now sim.VTimeInSec,
req *protocol.FlushReq,
) bool {
m.driver.gpuPort.Retrieve(now)
m.driver.gpuPort.RetrieveIncoming()

m.driver.logTaskToGPUClear(now, req)
m.driver.logTaskToGPUClear(req)

cmd, _ := m.driver.findCommandByReq(req)

cmd.RemoveReq(req)

m.driver.logTaskToGPUClear(now, req)
m.driver.logTaskToGPUClear(req)

return true
}
13 changes: 3 additions & 10 deletions driver/memorycopyglobalstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package driver
import (
"bytes"
"encoding/binary"

"github.com/sarchlab/akita/v4/sim"
)

// defaultMemoryCopyMiddleware handles memory copy commands and related
Expand All @@ -14,22 +12,20 @@ type globalStorageMemoryCopyMiddleware struct {
}

func (m *globalStorageMemoryCopyMiddleware) ProcessCommand(
now sim.VTimeInSec,
cmd Command,
queue *CommandQueue,
) (processed bool) {
switch cmd := cmd.(type) {
case *MemCopyH2DCommand:
return m.processMemCopyH2DCommand(now, cmd, queue)
return m.processMemCopyH2DCommand(cmd, queue)
case *MemCopyD2HCommand:
return m.processMemCopyD2HCommand(now, cmd, queue)
return m.processMemCopyD2HCommand(cmd, queue)
}

return false
}

func (m *globalStorageMemoryCopyMiddleware) processMemCopyH2DCommand(
now sim.VTimeInSec,
cmd *MemCopyH2DCommand,
queue *CommandQueue,
) bool {
Expand Down Expand Up @@ -70,7 +66,6 @@ func (m *globalStorageMemoryCopyMiddleware) processMemCopyH2DCommand(
}

func (m *globalStorageMemoryCopyMiddleware) processMemCopyD2HCommand(
now sim.VTimeInSec,
cmd *MemCopyD2HCommand,
queue *CommandQueue,
) bool {
Expand Down Expand Up @@ -111,8 +106,6 @@ func (m *globalStorageMemoryCopyMiddleware) processMemCopyD2HCommand(
return true
}

func (m *globalStorageMemoryCopyMiddleware) Tick(
now sim.VTimeInSec,
) (madeProgress bool) {
func (m *globalStorageMemoryCopyMiddleware) Tick() (madeProgress bool) {
return false
}
5 changes: 1 addition & 4 deletions driver/middleware.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package driver

import "github.com/sarchlab/akita/v4/sim"

// A Middleware is a pluggable element of the driver that can take care of the
// handling of certain types of commands and parts of the driver-GPU
// communication.
type Middleware interface {
ProcessCommand(
now sim.VTimeInSec,
cmd Command,
queue *CommandQueue,
) (processed bool)
Tick(now sim.VTimeInSec) (madeProgress bool)
Tick() (madeProgress bool)
}
14 changes: 7 additions & 7 deletions emu/computeunit.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,20 @@ func (cu *ComputeUnit) Handle(evt sim.Event) error {
}

// Tick ticks
func (cu *ComputeUnit) Tick(now sim.VTimeInSec) bool {
cu.processMapWGReq(now)
func (cu *ComputeUnit) Tick() bool {
cu.processMapWGReq()
return false
}

func (cu *ComputeUnit) processMapWGReq(now sim.VTimeInSec) {
msg := cu.ToDispatcher.Retrieve(now)
func (cu *ComputeUnit) processMapWGReq() {
msg := cu.ToDispatcher.RetrieveIncoming()
if msg == nil {
return
}

req := msg.(*protocol.MapWGReq)

now := cu.TickingComponent.TickScheduler.CurrentTime()
if cu.nextTick <= now {
cu.nextTick = sim.VTimeInSec(math.Ceil(float64(now)))
//cu.nextTick = cu.Freq.NextTick(req.RecvTime())
Expand All @@ -127,14 +128,13 @@ func (cu *ComputeUnit) runEmulation(evt *emulationEvent) error {
for len(cu.queueingWGs) > 0 {
wg := cu.queueingWGs[0]
cu.queueingWGs = cu.queueingWGs[1:]
cu.runWG(wg, evt.Time())
cu.runWG(wg)
}
return nil
}

func (cu *ComputeUnit) runWG(
req *protocol.MapWGReq,
now sim.VTimeInSec,
) error {
wg := req.WorkGroup
cu.initWfs(wg, req)
Expand All @@ -147,6 +147,7 @@ func (cu *ComputeUnit) runWG(
cu.resolveBarrier(wg)
}

now := cu.TickingComponent.TickScheduler.CurrentTime()
evt := NewWGCompleteEvent(cu.Freq.NextTick(now), cu, req)
cu.Engine.Schedule(evt)

Expand Down Expand Up @@ -386,7 +387,6 @@ func (cu *ComputeUnit) handleWGCompleteEvent(evt *WGCompleteEvent) error {
req := protocol.WGCompletionMsgBuilder{}.
WithSrc(cu.ToDispatcher).
WithDst(evt.Req.Src).
WithSendTime(evt.Time()).
WithRspTo(cu.finishedMapWGReqs).
Build()

Expand Down
Loading

0 comments on commit 3a489a0

Please sign in to comment.