Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additions to dma.go to handle multiple requests #40

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 117 additions & 15 deletions timing/cp/dma.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,54 @@ import (
"github.com/sarchlab/mgpusim/v3/protocol"
)

// A RequestCollection groups a single MemCopy Msg with the IDs of the
// Read/Write requests that correspond to it, and keeps a count of how many of
// those requests are still outstanding.
type RequestCollection struct {
// superiorRequest is the MemCopy message this collection tracks.
superiorRequest sim.Msg
// subordinateRequestIDs lists the IDs of the Read/Write requests recorded
// for superiorRequest. IDs are appended by appendSubordinateID; no method
// of this type removes them again.
subordinateRequestIDs []string
// subordinateCount is the number of subordinate requests still remaining.
// It is incremented by appendSubordinateID and decremented by
// decrementCountIfExists; isFinished reports true when it is zero.
subordinateCount int
}

// decrementCountIfExists searches the recorded subordinate request IDs for id.
// On the first match it lowers subordinateCount by one and returns true. If id
// is not in the list, the count is left untouched and false is returned.
//
// The matching ID is not removed from subordinateRequestIDs; only the counter
// changes.
func (rqC *RequestCollection) decrementCountIfExists(id string) bool {
	for i := 0; i < len(rqC.subordinateRequestIDs); i++ {
		if rqC.subordinateRequestIDs[i] != id {
			continue
		}

		rqC.subordinateCount--
		return true
	}

	return false
}

// isFinished reports whether no subordinate requests remain, i.e. whether
// subordinateCount is exactly zero. A true result means the superior request
// has finished processing.
func (rqC *RequestCollection) isFinished() bool {
	outstanding := rqC.subordinateCount
	return outstanding == 0
}

// getSuperior returns the message stored as this collection's superior
// request.
func (rqC *RequestCollection) getSuperior() sim.Msg {
	superior := rqC.superiorRequest
	return superior
}

// getSuperiorID returns the ID found in the metadata (Meta()) of this
// collection's superior request.
func (rqC *RequestCollection) getSuperiorID() string {
	meta := rqC.superiorRequest.Meta()
	return meta.ID
}

// appendSubordinateID records id as one more subordinate request of this
// collection: the remaining-request counter goes up by one and id is appended
// to the end of subordinateRequestIDs.
func (rqC *RequestCollection) appendSubordinateID(id string) {
	rqC.subordinateCount++
	rqC.subordinateRequestIDs = append(rqC.subordinateRequestIDs, id)
}

// NewRequestCollection creates a RequestCollection that tracks
// superiorRequest. The returned collection has no subordinate request IDs
// recorded and a subordinate count of zero.
func NewRequestCollection(
	superiorRequest sim.Msg,
) *RequestCollection {
	return &RequestCollection{
		superiorRequest:  superiorRequest,
		subordinateCount: 0,
	}
}

// A DMAEngine is responsible for accessing data that does not belong to
// the GPU that the DMAEngine works in.
type DMAEngine struct {
Expand All @@ -19,7 +67,10 @@ type DMAEngine struct {

localDataSource mem.LowModuleFinder

processingReq sim.Msg
processingReqs []*RequestCollection

processingReq sim.Msg
maxRequestCount uint64

toSendToMem []sim.Msg
toSendToCP []sim.Msg
Expand Down Expand Up @@ -92,15 +143,28 @@ func (dma *DMAEngine) processDataReadyRsp(
req := dma.removeReqFromPendingReqList(rsp.RespondTo).(*mem.ReadReq)
tracing.TraceReqFinalize(req, dma)

processing := dma.processingReq.(*protocol.MemCopyD2HReq)
found := false
result := &RequestCollection{}
for _, rc := range dma.processingReqs {
if rc.decrementCountIfExists(req.Meta().ID) {
result = rc
found = true
}
}

if !found {
panic("couldn't find requestcollection")
}

processing := result.getSuperior().(*protocol.MemCopyD2HReq)

offset := req.Address - processing.SrcAddress
copy(processing.DstBuffer[offset:], rsp.Data)
// fmt.Printf("Dma DataReady %x, %v\n", req.Address, rsp.Data)

if len(dma.pendingReqs) == 0 {
tracing.TraceReqComplete(dma.processingReq, dma)
dma.processingReq = nil
if result.isFinished() {
tracing.TraceReqComplete(processing, dma)
dma.removeReqFromProcessingReqList(processing.Meta().ID)

rsp := sim.GeneralRspBuilder{}.
WithDst(processing.Src).
Expand All @@ -119,10 +183,23 @@ func (dma *DMAEngine) processDoneRsp(
r := dma.removeReqFromPendingReqList(rsp.RespondTo)
tracing.TraceReqFinalize(r, dma)

processing := dma.processingReq.(*protocol.MemCopyH2DReq)
if len(dma.pendingReqs) == 0 {
tracing.TraceReqComplete(dma.processingReq, dma)
dma.processingReq = nil
found := false
result := &RequestCollection{}
for _, rc := range dma.processingReqs {
if rc.decrementCountIfExists(r.Meta().ID) {
result = rc
found = true
}
}

if !found {
panic("couldn't find requestcollection")
}

if result.isFinished() {
processing := result.getSuperior().(*protocol.MemCopyH2DReq)
tracing.TraceReqComplete(processing, dma)
dma.removeReqFromProcessingReqList(processing.Meta().ID)

rsp := sim.GeneralRspBuilder{}.
WithDst(processing.Src).
Expand Down Expand Up @@ -153,8 +230,25 @@ func (dma *DMAEngine) removeReqFromPendingReqList(id string) sim.Msg {
return reqToRet
}

// removeReqFromProcessingReqList removes from dma.processingReqs every
// RequestCollection whose superior request ID equals id, keeping the remaining
// collections in their original order.
//
// It panics with "not found" when no collection carries that ID, including
// when the list is empty.
func (dma *DMAEngine) removeReqFromProcessingReqList(id string) {
	found := false

	// Size by len, not len-1: a negative capacity makes make() itself panic
	// on an empty list, which would mask the intended "not found" panic.
	newList := make([]*RequestCollection, 0, len(dma.processingReqs))
	for _, r := range dma.processingReqs {
		if r.getSuperiorID() == id {
			found = true
			continue
		}
		newList = append(newList, r)
	}

	if !found {
		panic("not found")
	}

	dma.processingReqs = newList
}

func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
if dma.processingReq != nil {
if uint64(len(dma.processingReqs)) >= dma.maxRequestCount {
return false
}

Expand All @@ -164,12 +258,14 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
}
tracing.TraceReqReceive(req, dma)

dma.processingReq = req
rqC := NewRequestCollection(req)

dma.processingReqs = append(dma.processingReqs, rqC)
switch req := req.(type) {
case *protocol.MemCopyH2DReq:
dma.parseMemCopyH2D(now, req)
dma.parseMemCopyH2D(now, req, rqC)
case *protocol.MemCopyD2HReq:
dma.parseMemCopyD2H(now, req)
dma.parseMemCopyD2H(now, req, rqC)
default:
log.Panicf("cannot process request of type %s", reflect.TypeOf(req))
}
Expand All @@ -180,6 +276,7 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool {
func (dma *DMAEngine) parseMemCopyH2D(
now sim.VTimeInSec,
req *protocol.MemCopyH2DReq,
rqC *RequestCollection,
) {
offset := uint64(0)
lengthLeft := uint64(len(req.SrcBuffer))
Expand All @@ -205,9 +302,10 @@ func (dma *DMAEngine) parseMemCopyH2D(
Build()
dma.toSendToMem = append(dma.toSendToMem, reqToBottom)
dma.pendingReqs = append(dma.pendingReqs, reqToBottom)
rqC.appendSubordinateID(reqToBottom.Meta().ID)

tracing.TraceReqInitiate(reqToBottom, dma,
tracing.MsgIDAtReceiver(dma.processingReq, dma))
tracing.MsgIDAtReceiver(req, dma))

addr += length
lengthLeft -= length
Expand All @@ -218,6 +316,7 @@ func (dma *DMAEngine) parseMemCopyH2D(
func (dma *DMAEngine) parseMemCopyD2H(
now sim.VTimeInSec,
req *protocol.MemCopyD2HReq,
rqC *RequestCollection,
) {
offset := uint64(0)
lengthLeft := uint64(len(req.DstBuffer))
Expand All @@ -243,9 +342,10 @@ func (dma *DMAEngine) parseMemCopyD2H(
Build()
dma.toSendToMem = append(dma.toSendToMem, reqToBottom)
dma.pendingReqs = append(dma.pendingReqs, reqToBottom)
rqC.appendSubordinateID(reqToBottom.Meta().ID)

tracing.TraceReqInitiate(reqToBottom, dma,
tracing.MsgIDAtReceiver(dma.processingReq, dma))
tracing.MsgIDAtReceiver(req, dma))

addr += length
lengthLeft -= length
Expand All @@ -267,6 +367,8 @@ func NewDMAEngine(
dma.Log2AccessSize = 6
dma.localDataSource = localDataSource

dma.maxRequestCount = 4

dma.ToCP = sim.NewLimitNumMsgPort(dma, 40960000, name+".ToCP")
dma.ToMem = sim.NewLimitNumMsgPort(dma, 64, name+".ToMem")

Expand Down
Loading