diff --git a/protocol/driverprotocol.go b/protocol/driverprotocol.go index e1ea3e9c..594944dc 100644 --- a/protocol/driverprotocol.go +++ b/protocol/driverprotocol.go @@ -110,6 +110,7 @@ func NewMemCopyH2DReq( ) *MemCopyH2DReq { req := new(MemCopyH2DReq) req.ID = sim.GetIDGenerator().Generate() + req.MsgMeta.TrafficBytes = len(srcBuffer) req.SendTime = time req.Src = src req.Dst = dst @@ -131,7 +132,7 @@ func (m *MemCopyD2HReq) Meta() *sim.MsgMeta { return &m.MsgMeta } -// NewMemCopyD2HReq created a new MemCopyH2DReq +// NewMemCopyD2HReq created a new MemCopyD2HReq func NewMemCopyD2HReq( time sim.VTimeInSec, src, dst sim.Port, @@ -140,6 +141,7 @@ func NewMemCopyD2HReq( ) *MemCopyD2HReq { req := new(MemCopyD2HReq) req.ID = sim.GetIDGenerator().Generate() + req.MsgMeta.TrafficBytes = len(dstBuffer) req.SendTime = time req.Src = src req.Dst = dst diff --git a/timing/cp/dma.go b/timing/cp/dma.go index d1d321a7..454b1d1f 100644 --- a/timing/cp/dma.go +++ b/timing/cp/dma.go @@ -10,6 +10,54 @@ import ( "github.com/sarchlab/mgpusim/v3/protocol" ) +// A RequestCollection contains a single MemCopy Msg and the IDs of the Read/Write +// requests that correspond to it, as well as the number of remaining requests +type RequestCollection struct { + superiorRequest sim.Msg + subordinateRequestIDs []string + subordinateCount int +} + +// removeIDIfExists reduces the subordinate count if a specific ID is present in the list +// of subordinate IDs, returning true if it was and false if it was not +func (rqC *RequestCollection) decrementCountIfExists(id string) bool { + for _, str := range rqC.subordinateRequestIDs { + if id == str { + rqC.subordinateCount -= 1 + return true + } + } + return false +} + +// isFinished returns true if the subordinate count is zero (i.e. the superior request is finished processing) +func (rqC *RequestCollection) isFinished() bool { + return rqC.subordinateCount == 0 +} + +func (rqC *RequestCollection) getSuperior() sim.Msg { + return rqC.superiorRequest +} + +func (rqC *RequestCollection) getSuperiorID() string { + return rqC.superiorRequest.Meta().ID +} + +// appendSubordinateID adds a message ID to the list and increases the count +func (rqC *RequestCollection) appendSubordinateID(id string) { + rqC.subordinateRequestIDs = append(rqC.subordinateRequestIDs, id) + rqC.subordinateCount += 1 +} + +func NewRequestCollection( + superiorRequest sim.Msg, +) *RequestCollection { + rqC := new(RequestCollection) + rqC.superiorRequest = superiorRequest + rqC.subordinateCount = 0 + return rqC +} + // A DMAEngine is responsible for accessing data that does not belongs to // the GPU that the DMAEngine works in. type DMAEngine struct { @@ -19,7 +67,10 @@ type DMAEngine struct { localDataSource mem.LowModuleFinder - processingReq sim.Msg + processingReqs []*RequestCollection + + processingReq sim.Msg + maxRequestCount uint64 toSendToMem []sim.Msg toSendToCP []sim.Msg @@ -92,15 +143,28 @@ func (dma *DMAEngine) processDataReadyRsp( req := dma.removeReqFromPendingReqList(rsp.RespondTo).(*mem.ReadReq) tracing.TraceReqFinalize(req, dma) - processing := dma.processingReq.(*protocol.MemCopyD2HReq) + found := false + result := &RequestCollection{} + for _, rc := range dma.processingReqs { + if rc.decrementCountIfExists(req.Meta().ID) { + result = rc + found = true + } + } + + if !found { + panic("couldn't find requestcollection") + } + + processing := result.getSuperior().(*protocol.MemCopyD2HReq) offset := req.Address - processing.SrcAddress copy(processing.DstBuffer[offset:], rsp.Data) // fmt.Printf("Dma DataReady %x, %v\n", req.Address, rsp.Data) - if len(dma.pendingReqs) == 0 { - tracing.TraceReqComplete(dma.processingReq, dma) - dma.processingReq = nil + if result.isFinished() { + tracing.TraceReqComplete(processing, dma) + dma.removeReqFromProcessingReqList(processing.Meta().ID) rsp := sim.GeneralRspBuilder{}. WithDst(processing.Src). @@ -119,10 +183,23 @@ func (dma *DMAEngine) processDoneRsp( r := dma.removeReqFromPendingReqList(rsp.RespondTo) tracing.TraceReqFinalize(r, dma) - processing := dma.processingReq.(*protocol.MemCopyH2DReq) - if len(dma.pendingReqs) == 0 { - tracing.TraceReqComplete(dma.processingReq, dma) - dma.processingReq = nil + found := false + result := &RequestCollection{} + for _, rc := range dma.processingReqs { + if rc.decrementCountIfExists(r.Meta().ID) { + result = rc + found = true + } + } + + if !found { + panic("couldn't find requestcollection") + } + + if result.isFinished() { + processing := result.getSuperior().(*protocol.MemCopyH2DReq) + tracing.TraceReqComplete(processing, dma) + dma.removeReqFromProcessingReqList(processing.Meta().ID) rsp := sim.GeneralRspBuilder{}. WithDst(processing.Src). @@ -153,8 +230,25 @@ func (dma *DMAEngine) removeReqFromPendingReqList(id string) sim.Msg { return reqToRet } +func (dma *DMAEngine) removeReqFromProcessingReqList(id string) { + found := false + newList := make([]*RequestCollection, 0, len(dma.processingReqs)-1) + for _, r := range dma.processingReqs { + if r.getSuperiorID() == id { + found = true + } else { + newList = append(newList, r) + } + } + dma.processingReqs = newList + + if !found { + panic("not found") + } +} + func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { - if dma.processingReq != nil { + if uint64(len(dma.processingReqs)) >= dma.maxRequestCount { return false } @@ -164,12 +258,14 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { } tracing.TraceReqReceive(req, dma) - dma.processingReq = req + rqC := NewRequestCollection(req) + + dma.processingReqs = append(dma.processingReqs, rqC) switch req := req.(type) { case *protocol.MemCopyH2DReq: - dma.parseMemCopyH2D(now, req) + dma.parseMemCopyH2D(now, req, rqC) case *protocol.MemCopyD2HReq: - dma.parseMemCopyD2H(now, req) + dma.parseMemCopyD2H(now, req, rqC) default: log.Panicf("cannot process request of type %s", reflect.TypeOf(req)) } @@ -180,6 +276,7 @@ func (dma *DMAEngine) parseFromCP(now sim.VTimeInSec) bool { func (dma *DMAEngine) parseMemCopyH2D( now sim.VTimeInSec, req *protocol.MemCopyH2DReq, + rqC *RequestCollection, ) { offset := uint64(0) lengthLeft := uint64(len(req.SrcBuffer)) @@ -205,9 +302,10 @@ func (dma *DMAEngine) parseMemCopyH2D( Build() dma.toSendToMem = append(dma.toSendToMem, reqToBottom) dma.pendingReqs = append(dma.pendingReqs, reqToBottom) + rqC.appendSubordinateID(reqToBottom.Meta().ID) tracing.TraceReqInitiate(reqToBottom, dma, - tracing.MsgIDAtReceiver(dma.processingReq, dma)) + tracing.MsgIDAtReceiver(req, dma)) addr += length lengthLeft -= length @@ -218,6 +316,7 @@ func (dma *DMAEngine) parseMemCopyH2D( func (dma *DMAEngine) parseMemCopyD2H( now sim.VTimeInSec, req *protocol.MemCopyD2HReq, + rqC *RequestCollection, ) { offset := uint64(0) lengthLeft := uint64(len(req.DstBuffer)) @@ -243,9 +342,10 @@ func (dma *DMAEngine) parseMemCopyD2H( Build() dma.toSendToMem = append(dma.toSendToMem, reqToBottom) dma.pendingReqs = append(dma.pendingReqs, reqToBottom) + rqC.appendSubordinateID(reqToBottom.Meta().ID) tracing.TraceReqInitiate(reqToBottom, dma, - tracing.MsgIDAtReceiver(dma.processingReq, dma)) + tracing.MsgIDAtReceiver(req, dma)) addr += length lengthLeft -= length @@ -267,6 +367,8 @@ func NewDMAEngine( dma.Log2AccessSize = 6 dma.localDataSource = localDataSource + dma.maxRequestCount = 4 + dma.ToCP = sim.NewLimitNumMsgPort(dma, 40960000, name+".ToCP") dma.ToMem = sim.NewLimitNumMsgPort(dma, 64, name+".ToMem") diff --git a/timing/cp/dma_test.go b/timing/cp/dma_test.go index 39d7c94c..b374d74b 100644 --- a/timing/cp/dma_test.go +++ b/timing/cp/dma_test.go @@ -35,10 +35,14 @@ var _ = Describe("DMAEngine", func() { mockCtrl.Finish() }) - It("should stall if dma is processing another request", func() { - srcBuf := make([]byte, 128) - req := protocol.NewMemCopyH2DReq(5, nil, toCP, srcBuf, 20) - dmaEngine.processingReq = req + It("should stall if dma is processing max request number", func() { + for i := 0; i < int(dmaEngine.maxRequestCount); i++ { + srcBuf := make([]byte, 128) + req := protocol.NewMemCopyH2DReq(5, nil, toCP, srcBuf, uint64(20+128*i)) + rqC := NewRequestCollection(req) + + dmaEngine.processingReqs = append(dmaEngine.processingReqs, rqC) + } madeProgress := dmaEngine.parseFromCP(6) @@ -54,7 +58,7 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromCP(6) - Expect(dmaEngine.processingReq).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0].superiorRequest).To(BeIdenticalTo(req)) Expect(dmaEngine.toSendToMem).To(HaveLen(3)) Expect(dmaEngine.toSendToMem[0].(*mem.WriteReq).Address). To(Equal(uint64(20))) @@ -74,7 +78,7 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromCP(6) - Expect(dmaEngine.processingReq).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0].superiorRequest).To(BeIdenticalTo(req)) Expect(dmaEngine.toSendToMem).To(HaveLen(3)) Expect(dmaEngine.toSendToMem[0].(*mem.ReadReq).Address). To(Equal(uint64(20))) @@ -89,7 +93,8 @@ var _ = Describe("DMAEngine", func() { It("should parse DataReady from mem", func() { dstBuf := make([]byte, 128) req := protocol.NewMemCopyD2HReq(5, nil, toCP, 20, dstBuf) - dmaEngine.processingReq = req + rqC := NewRequestCollection(req) + dmaEngine.processingReqs = append(dmaEngine.processingReqs, rqC) reqToBottom1 := mem.ReadReqBuilder{}. WithSendTime(6). @@ -110,8 +115,11 @@ var _ = Describe("DMAEngine", func() { WithByteSize(64). Build() dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom1) + rqC.appendSubordinateID(reqToBottom1.Meta().ID) dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom2) + rqC.appendSubordinateID(reqToBottom2.Meta().ID) dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom3) + rqC.appendSubordinateID(reqToBottom3.Meta().ID) dataReady := mem.DataReadyRspBuilder{}. WithSendTime(7). @@ -132,7 +140,8 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromMem(10) Expect(madeProgress).To(BeTrue()) - Expect(dmaEngine.processingReq).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0].superiorRequest).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0]).To(BeIdenticalTo(rqC)) Expect(dmaEngine.pendingReqs).NotTo(ContainElement(reqToBottom2)) Expect(dmaEngine.pendingReqs).To(ContainElement(reqToBottom1)) Expect(dmaEngine.pendingReqs).To(ContainElement(reqToBottom3)) @@ -142,7 +151,8 @@ var _ = Describe("DMAEngine", func() { It("should respond MemCopyD2H", func() { dstBuf := make([]byte, 128) req := protocol.NewMemCopyD2HReq(5, nil, toCP, 20, dstBuf) - dmaEngine.processingReq = req + rqC := NewRequestCollection(req) + dmaEngine.processingReqs = append(dmaEngine.processingReqs, rqC) reqToBottom2 := mem.ReadReqBuilder{}. WithSendTime(6). @@ -151,6 +161,7 @@ var _ = Describe("DMAEngine", func() { WithByteSize(64). Build() dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom2) + rqC.appendSubordinateID(reqToBottom2.Meta().ID) dataReady := mem.DataReadyRspBuilder{}. WithSendTime(7). @@ -172,7 +183,7 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromMem(10) Expect(madeProgress).To(BeTrue()) - Expect(dmaEngine.processingReq).To(BeNil()) + Expect(dmaEngine.processingReqs).To(BeEmpty()) Expect(dmaEngine.pendingReqs).NotTo(ContainElement(reqToBottom2)) Expect(dstBuf[44:108]).To(Equal(dataReady.Data)) Expect(dmaEngine.toSendToCP[0].(*sim.GeneralRsp).OriginalReq). @@ -182,7 +193,8 @@ var _ = Describe("DMAEngine", func() { It("should parse Done from mem", func() { srcBuf := make([]byte, 128) req := protocol.NewMemCopyH2DReq(5, nil, toCP, srcBuf, 20) - dmaEngine.processingReq = req + rqC := NewRequestCollection(req) + dmaEngine.processingReqs = append(dmaEngine.processingReqs, rqC) reqToBottom1 := mem.WriteReqBuilder{}. WithSendTime(6). @@ -201,8 +213,11 @@ var _ = Describe("DMAEngine", func() { Build() dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom1) + rqC.appendSubordinateID(reqToBottom1.Meta().ID) dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom2) + rqC.appendSubordinateID(reqToBottom2.Meta().ID) dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom3) + rqC.appendSubordinateID(reqToBottom3.Meta().ID) done := mem.WriteDoneRspBuilder{}. WithSendTime(7). @@ -215,7 +230,8 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromMem(10) Expect(madeProgress).To(BeTrue()) - Expect(dmaEngine.processingReq).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0].superiorRequest).To(BeIdenticalTo(req)) + Expect(dmaEngine.processingReqs[0]).To(BeIdenticalTo(rqC)) Expect(dmaEngine.pendingReqs).NotTo(ContainElement(reqToBottom2)) Expect(dmaEngine.pendingReqs).To(ContainElement(reqToBottom1)) Expect(dmaEngine.pendingReqs).To(ContainElement(reqToBottom3)) @@ -224,7 +240,8 @@ var _ = Describe("DMAEngine", func() { It("should send MemCopyH2D to top", func() { srcBuf := make([]byte, 128) req := protocol.NewMemCopyH2DReq(5, nil, toCP, srcBuf, 20) - dmaEngine.processingReq = req + rqC := NewRequestCollection(req) + dmaEngine.processingReqs = append(dmaEngine.processingReqs, rqC) reqToBottom2 := mem.WriteReqBuilder{}. WithSendTime(6). @@ -232,6 +249,7 @@ var _ = Describe("DMAEngine", func() { WithAddress(64). Build() dmaEngine.pendingReqs = append(dmaEngine.pendingReqs, reqToBottom2) + rqC.appendSubordinateID(reqToBottom2.Meta().ID) done := mem.WriteDoneRspBuilder{}. WithSendTime(7). @@ -244,7 +262,7 @@ var _ = Describe("DMAEngine", func() { madeProgress := dmaEngine.parseFromMem(10) Expect(madeProgress).To(BeTrue()) - Expect(dmaEngine.processingReq).To(BeNil()) + Expect(dmaEngine.processingReqs).To(BeEmpty()) Expect(dmaEngine.pendingReqs).NotTo(ContainElement(reqToBottom2)) Expect(dmaEngine.toSendToCP[0].(*sim.GeneralRsp).OriginalReq). To(BeIdenticalTo(req))