diff --git a/timing/cp/dma.go b/timing/cp/dma.go index d1d321a7..454b1d1f 100644 --- a/timing/cp/dma.go +++ b/timing/cp/dma.go @@ -10,6 +10,54 @@ import ( "github.com/sarchlab/mgpusim/v3/protocol" ) +// A RequestCollection contains a single MemCopy Msg and the IDs of the Read/Write +// requests that correspond to it, as well as the number of remaining requests +type RequestCollection struct { + superiorRequest sim.Msg + subordinateRequestIDs []string + subordinateCount int +} + +// removeIDIfExists reduces the subordinate count if a specific ID is present in the list +// of subordinate IDs, returning true if it was and false if it was not +func (rqC *RequestCollection) decrementCountIfExists(id string) bool { + for _, str := range rqC.subordinateRequestIDs { + if id == str { + rqC.subordinateCount -= 1 + return true + } + } + return false +} + +// isFinished returns true if the subordinate count is zero (i.e. the superior request is finished processing) +func (rqC *RequestCollection) isFinished() bool { + return rqC.subordinateCount == 0 +} + +func (rqC *RequestCollection) getSuperior() sim.Msg { + return rqC.superiorRequest +} + +func (rqC *RequestCollection) getSuperiorID() string { + return rqC.superiorRequest.Meta().ID +} + +// appendSubordinateID adds a message ID to the list and increases the count +func (rqC *RequestCollection) appendSubordinateID(id string) { + rqC.subordinateRequestIDs = append(rqC.subordinateRequestIDs, id) + rqC.subordinateCount += 1 +} + +func NewRequestCollection( + superiorRequest sim.Msg, +) *RequestCollection { + rqC := new(RequestCollection) + rqC.superiorRequest = superiorRequest + rqC.subordinateCount = 0 + return rqC +} + // A DMAEngine is responsible for accessing data that does not belongs to // the GPU that the DMAEngine works in. type DMAEngine struct { @@ -19,7 +67,10 @@ type DMAEngine struct { localDataSource mem.LowModuleFinder - processingReq sim.Msg + processingReqs []*RequestCollection + + processingReq sim.Msg + maxRequestCount uint64 toSendToMem []sim.Msg toSendToCP []sim.Msg @@ -92,15 +143,28 @@ func (dma *DMAEngine) processDataReadyRsp( req := dma.removeReqFromPendingReqList(rsp.RespondTo).(*mem.ReadReq) tracing.TraceReqFinalize(req, dma) - processing := dma.processingReq.(*protocol.MemCopyD2HReq) + found := false + result := &RequestCollection{} + for _, rc := range dma.processingReqs { + if rc.decrementCountIfExists(req.Meta().ID) { + result = rc + found = true + } + } + + if !found { + panic("couldn't find requestcollection") + } + + processing := result.getSuperior().(*protocol.MemCopyD2HReq) offset := req.Address - processing.SrcAddress copy(processing.DstBuffer[offset:], rsp.Data) // fmt.Printf("Dma DataReady %x, %v\n", req.Address, rsp.Data) - if len(dma.pendingReqs) == 0 { - tracing.TraceReqComplete(dma.processingReq, dma) - dma.processingReq = nil + if result.isFinished() { + tracing.TraceReqComplete(processing, dma) + dma.removeReqFromProcessingReqList(processing.Meta().ID) rsp := sim.GeneralRspBuilder{}. WithDst(processing.Src). @@ -119,10 +183,23 @@ func (dma *DMAEngine) processDoneRsp( r := dma.removeReqFromPendingReqList(rsp.RespondTo) tracing.TraceReqFinalize(r, dma) - processing := dma.processingReq.(*protocol.MemCopyH2DReq) - if len(dma.pendingReqs) == 0 { - tracing.TraceReqComplete(dma.processingReq, dma) - dma.processingReq = nil + found := false + result := &RequestCollection{} + for _, rc := range dma.processingReqs { + if rc.decrementCountIfExists(r.Meta().ID) { + result = rc + found = true + } + } + + if !found { + panic("couldn't find requestcollection") + } + + if result.isFinished() { + processing := result.getSuperior().(*protocol.MemCopyH2DReq) + tracing.TraceReqComplete(processing, dma) + dma.removeReqFromProcessingReqList(processing.Meta().ID) rsp := sim.GeneralRspBuilder{}. WithDst(processing.Src). @@ -153,8 +230,25 @@ func (dma *DMAEngine) removeReqFromPendingReqList(id string) sim.Msg { return reqToRet } +func (dma *DMAEngine) removeReqFromProcessingReqList(id string) { + found := false + newList := make([]*RequestCollection, 0, len(dma.processingReqs)-1) + for _, r := range dma.processingReqs { + if r.getSuperiorID() == id { + found = true + } else { + newList = append(newList, r) + } + } + dma.processingReqs = newList + + if !found { + panic("not found") + } +} + func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { - if dma.processingReq != nil { + if uint64(len(dma.processingReqs)) >= dma.maxRequestCount { return false } @@ -164,12 +258,14 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { } tracing.TraceReqReceive(req, dma) - dma.processingReq = req + rqC := NewRequestCollection(req) + + dma.processingReqs = append(dma.processingReqs, rqC) switch req := req.(type) { case *protocol.MemCopyH2DReq: - dma.parseMemCopyH2D(now, req) + dma.parseMemCopyH2D(now, req, rqC) case *protocol.MemCopyD2HReq: - dma.parseMemCopyD2H(now, req) + dma.parseMemCopyD2H(now, req, rqC) default: log.Panicf("cannot process request of type %s", reflect.TypeOf(req)) } @@ -180,6 +276,7 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { func (dma *DMAEngine) parseMemCopyH2D( now sim.VTimeInSec, req *protocol.MemCopyH2DReq, + rqC *RequestCollection, ) { offset := uint64(0) lengthLeft := uint64(len(req.SrcBuffer)) @@ -205,9 +302,10 @@ func (dma *DMAEngine) parseMemCopyH2D( Build() dma.toSendToMem = append(dma.toSendToMem, reqToBottom) dma.pendingReqs = append(dma.pendingReqs, reqToBottom) + rqC.appendSubordinateID(reqToBottom.Meta().ID) tracing.TraceReqInitiate(reqToBottom, dma, - tracing.MsgIDAtReceiver(dma.processingReq, dma)) + tracing.MsgIDAtReceiver(req, dma)) addr += length lengthLeft -= length @@ -218,6 +316,7 @@ func (dma *DMAEngine) parseMemCopyH2D( func (dma *DMAEngine) parseMemCopyD2H( now sim.VTimeInSec, req *protocol.MemCopyD2HReq, + rqC *RequestCollection, ) { offset := uint64(0) lengthLeft := uint64(len(req.DstBuffer)) @@ -243,9 +342,10 @@ func (dma *DMAEngine) parseMemCopyD2H( Build() dma.toSendToMem = append(dma.toSendToMem, reqToBottom) dma.pendingReqs = append(dma.pendingReqs, reqToBottom) + rqC.appendSubordinateID(reqToBottom.Meta().ID) tracing.TraceReqInitiate(reqToBottom, dma, - tracing.MsgIDAtReceiver(dma.processingReq, dma)) + tracing.MsgIDAtReceiver(req, dma)) addr += length lengthLeft -= length @@ -267,6 +367,8 @@ func NewDMAEngine( dma.Log2AccessSize = 6 dma.localDataSource = localDataSource + dma.maxRequestCount = 4 + dma.ToCP = sim.NewLimitNumMsgPort(dma, 40960000, name+".ToCP") dma.ToMem = sim.NewLimitNumMsgPort(dma, 64, name+".ToMem")