Skip to content

Commit

Permalink
Additions to dma.go to handle multiple requests (#40)
Browse files Browse the repository at this point in the history
* Additions to dma.go to handle multiple requests

* Added size to PCIe message metas

* Adjustments to dma tests
  • Loading branch information
nichosta authored Mar 6, 2024
1 parent 70cf961 commit 5a9134e
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 30 deletions.
4 changes: 3 additions & 1 deletion protocol/driverprotocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func NewMemCopyH2DReq(
) *MemCopyH2DReq {
req := new(MemCopyH2DReq)
req.ID = sim.GetIDGenerator().Generate()
req.MsgMeta.TrafficBytes = len(srcBuffer)
req.SendTime = time
req.Src = src
req.Dst = dst
Expand All @@ -131,7 +132,7 @@ func (m *MemCopyD2HReq) Meta() *sim.MsgMeta {
return &m.MsgMeta
}

// NewMemCopyD2HReq created a new MemCopyH2DReq
// NewMemCopyD2HReq creates a new MemCopyD2HReq
func NewMemCopyD2HReq(
time sim.VTimeInSec,
src, dst sim.Port,
Expand All @@ -140,6 +141,7 @@ func NewMemCopyD2HReq(
) *MemCopyD2HReq {
req := new(MemCopyD2HReq)
req.ID = sim.GetIDGenerator().Generate()
req.MsgMeta.TrafficBytes = len(dstBuffer)
req.SendTime = time
req.Src = src
req.Dst = dst
Expand Down
132 changes: 117 additions & 15 deletions timing/cp/dma.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,54 @@ import (
"github.com/sarchlab/mgpusim/v3/protocol"
)

// A RequestCollection contains a single MemCopy Msg (the superior request)
// and the IDs of the Read/Write requests that correspond to it (the
// subordinate requests), as well as the number of remaining requests.
type RequestCollection struct {
// superiorRequest is the message this collection was created for.
superiorRequest sim.Msg
// subordinateRequestIDs holds the IDs added through appendSubordinateID.
subordinateRequestIDs []string
// subordinateCount is raised by appendSubordinateID and lowered by
// decrementCountIfExists; isFinished reports true once it is zero.
subordinateCount int
}

// decrementCountIfExists reduces the subordinate count by one if id is present
// in the list of subordinate request IDs. It returns true if the ID was found
// and false if it was not. The ID itself is left in the list; only the count
// changes.
func (rc *RequestCollection) decrementCountIfExists(id string) bool {
	for _, subID := range rc.subordinateRequestIDs {
		if subID != id {
			continue
		}

		rc.subordinateCount--
		return true
	}

	return false
}

// isFinished reports whether no subordinate requests remain outstanding,
// i.e. the subordinate count has dropped to zero and the superior request
// has finished processing.
func (rc *RequestCollection) isFinished() bool {
	remaining := rc.subordinateCount
	return remaining == 0
}

// getSuperior returns the superior request stored in this collection.
func (rc *RequestCollection) getSuperior() sim.Msg {
	superior := rc.superiorRequest
	return superior
}

// getSuperiorID returns the ID found in the metadata of the superior request.
func (rc *RequestCollection) getSuperiorID() string {
	meta := rc.superiorRequest.Meta()
	return meta.ID
}

// appendSubordinateID records id as a subordinate request of this collection
// and bumps the subordinate count by one.
func (rc *RequestCollection) appendSubordinateID(id string) {
	rc.subordinateCount++
	rc.subordinateRequestIDs = append(rc.subordinateRequestIDs, id)
}

// NewRequestCollection creates a RequestCollection for superiorRequest with
// no subordinate request IDs and a subordinate count of zero.
func NewRequestCollection(
	superiorRequest sim.Msg,
) *RequestCollection {
	return &RequestCollection{
		superiorRequest:  superiorRequest,
		subordinateCount: 0,
	}
}

// A DMAEngine is responsible for accessing data that does not belong to
// the GPU that the DMAEngine works in.
type DMAEngine struct {
Expand All @@ -19,7 +67,10 @@ type DMAEngine struct {

localDataSource mem.LowModuleFinder

processingReq sim.Msg
processingReqs []*RequestCollection

processingReq sim.Msg
maxRequestCount uint64

toSendToMem []sim.Msg
toSendToCP []sim.Msg
Expand Down Expand Up @@ -92,15 +143,28 @@ func (dma *DMAEngine) processDataReadyRsp(
req := dma.removeReqFromPendingReqList(rsp.RespondTo).(*mem.ReadReq)
tracing.TraceReqFinalize(req, dma)

processing := dma.processingReq.(*protocol.MemCopyD2HReq)
found := false
result := &RequestCollection{}
for _, rc := range dma.processingReqs {
if rc.decrementCountIfExists(req.Meta().ID) {
result = rc
found = true
}
}

if !found {
panic("couldn't find requestcollection")
}

processing := result.getSuperior().(*protocol.MemCopyD2HReq)

offset := req.Address - processing.SrcAddress
copy(processing.DstBuffer[offset:], rsp.Data)
// fmt.Printf("Dma DataReady %x, %v\n", req.Address, rsp.Data)

if len(dma.pendingReqs) == 0 {
tracing.TraceReqComplete(dma.processingReq, dma)
dma.processingReq = nil
if result.isFinished() {
tracing.TraceReqComplete(processing, dma)
dma.removeReqFromProcessingReqList(processing.Meta().ID)

rsp := sim.GeneralRspBuilder{}.
WithDst(processing.Src).
Expand All @@ -119,10 +183,23 @@ func (dma *DMAEngine) processDoneRsp(
r := dma.removeReqFromPendingReqList(rsp.RespondTo)
tracing.TraceReqFinalize(r, dma)

processing := dma.processingReq.(*protocol.MemCopyH2DReq)
if len(dma.pendingReqs) == 0 {
tracing.TraceReqComplete(dma.processingReq, dma)
dma.processingReq = nil
found := false
result := &RequestCollection{}
for _, rc := range dma.processingReqs {
if rc.decrementCountIfExists(r.Meta().ID) {
result = rc
found = true
}
}

if !found {
panic("couldn't find requestcollection")
}

if result.isFinished() {
processing := result.getSuperior().(*protocol.MemCopyH2DReq)
tracing.TraceReqComplete(processing, dma)
dma.removeReqFromProcessingReqList(processing.Meta().ID)

rsp := sim.GeneralRspBuilder{}.
WithDst(processing.Src).
Expand Down Expand Up @@ -153,8 +230,25 @@ func (dma *DMAEngine) removeReqFromPendingReqList(id string) sim.Msg {
return reqToRet
}

// removeReqFromProcessingReqList drops every RequestCollection whose superior
// request ID equals id from dma.processingReqs. It panics with "not found" if
// no collection matches.
func (dma *DMAEngine) removeReqFromProcessingReqList(id string) {
	found := false

	// Size the new list by the current length rather than length-1: with an
	// empty list, length-1 is a negative capacity and make would panic before
	// the intended "not found" check below is reached.
	remaining := make([]*RequestCollection, 0, len(dma.processingReqs))
	for _, rc := range dma.processingReqs {
		if rc.getSuperiorID() == id {
			found = true
			continue
		}

		remaining = append(remaining, rc)
	}
	dma.processingReqs = remaining

	if !found {
		panic("not found")
	}
}

func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
if dma.processingReq != nil {
if uint64(len(dma.processingReqs)) >= dma.maxRequestCount {
return false
}

Expand All @@ -164,12 +258,14 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
}
tracing.TraceReqReceive(req, dma)

dma.processingReq = req
rqC := NewRequestCollection(req)

dma.processingReqs = append(dma.processingReqs, rqC)
switch req := req.(type) {
case *protocol.MemCopyH2DReq:
dma.parseMemCopyH2D(now, req)
dma.parseMemCopyH2D(now, req, rqC)
case *protocol.MemCopyD2HReq:
dma.parseMemCopyD2H(now, req)
dma.parseMemCopyD2H(now, req, rqC)
default:
log.Panicf("cannot process request of type %s", reflect.TypeOf(req))
}
Expand All @@ -180,6 +276,7 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
func (dma *DMAEngine) parseMemCopyH2D(
now sim.VTimeInSec,
req *protocol.MemCopyH2DReq,
rqC *RequestCollection,
) {
offset := uint64(0)
lengthLeft := uint64(len(req.SrcBuffer))
Expand All @@ -205,9 +302,10 @@ func (dma *DMAEngine) parseMemCopyH2D(
Build()
dma.toSendToMem = append(dma.toSendToMem, reqToBottom)
dma.pendingReqs = append(dma.pendingReqs, reqToBottom)
rqC.appendSubordinateID(reqToBottom.Meta().ID)

tracing.TraceReqInitiate(reqToBottom, dma,
tracing.MsgIDAtReceiver(dma.processingReq, dma))
tracing.MsgIDAtReceiver(req, dma))

addr += length
lengthLeft -= length
Expand All @@ -218,6 +316,7 @@ func (dma *DMAEngine) parseMemCopyH2D(
func (dma *DMAEngine) parseMemCopyD2H(
now sim.VTimeInSec,
req *protocol.MemCopyD2HReq,
rqC *RequestCollection,
) {
offset := uint64(0)
lengthLeft := uint64(len(req.DstBuffer))
Expand All @@ -243,9 +342,10 @@ func (dma *DMAEngine) parseMemCopyD2H(
Build()
dma.toSendToMem = append(dma.toSendToMem, reqToBottom)
dma.pendingReqs = append(dma.pendingReqs, reqToBottom)
rqC.appendSubordinateID(reqToBottom.Meta().ID)

tracing.TraceReqInitiate(reqToBottom, dma,
tracing.MsgIDAtReceiver(dma.processingReq, dma))
tracing.MsgIDAtReceiver(req, dma))

addr += length
lengthLeft -= length
Expand All @@ -267,6 +367,8 @@ func NewDMAEngine(
dma.Log2AccessSize = 6
dma.localDataSource = localDataSource

dma.maxRequestCount = 4

dma.ToCP = sim.NewLimitNumMsgPort(dma, 40960000, name+".ToCP")
dma.ToMem = sim.NewLimitNumMsgPort(dma, 64, name+".ToMem")

Expand Down
Loading

0 comments on commit 5a9134e

Please sign in to comment.