Skip to content

Commit

Permalink
Fixing constant(index::) automatically set to zero
Browse files Browse the repository at this point in the history
Also adding memory leak tests.
  • Loading branch information
RockfordWei committed Jun 20, 2017
1 parent cd1cc90 commit aa2bea8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
11 changes: 7 additions & 4 deletions Sources/PerfectTensorFlow/PerfectTensorFlow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,12 @@ public class TensorFlow {


deinit {
if autoDestroy {
if autoDestroy { //, let mem = TFLib.TensorData(tensor) {
// let sz = TFLib.TensorByteSize(tensor)
TFLib.DeleteTensor(tensor)
// if sz > 0 {
// mem.deallocate(bytes:sz, alignedTo: 0)
//}//end if
}//end if
}

Expand Down Expand Up @@ -1530,14 +1534,13 @@ public class TensorFlow {
}//end if

let pOutputs = UnsafeMutablePointer<Output>.allocate(capacity: count)
defer { pOutputs.deallocate(capacity: count) }
TFLib.GraphImportGraphDefWithReturnOutputs(graph, buf.buffer, options.options, pOutputs, Int32(count), status.status)
guard status.code == .OK else {
pOutputs.deallocate(capacity: count)
throw Panic.FAULT(reason: status.message)
}//end guard
let buffered = UnsafeMutableBufferPointer<Output>(start: pOutputs, count: count)
let outputs = Array(buffered)
pOutputs.deallocate(capacity: count)
return outputs
}//end func

Expand Down Expand Up @@ -1671,7 +1674,7 @@ public class TensorFlow {

public func constant<T>(name: String, value: T, index:Int = 0) throws -> Output {
let t = try Tensor.Scalar(value)
return try self.const(tensor: t, name: name).asOutput(0)
return try self.const(tensor: t, name: name).asOutput(index)
}

public func constantArray<T>(name: String, value: [T], index:Int = 0) throws -> Output {
Expand Down
39 changes: 23 additions & 16 deletions Tests/PerfectTensorFlowTests/PerfectTensorFlowTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -150,49 +150,57 @@ class PerfectTensorFlowTests: XCTestCase {
let img = try LabelImage()
guard let eight = Data.Load("/tmp/testdata/8.jpg") else
{ throw TF.Panic.FAULT(reason: "hand write file 8.jpg not found")}
let x = try img.match(image: eight)
XCTAssertGreaterThan(x, 0)
print(x)
for _ in 0 ... 30 {
#if os(Linux)
let x = try img.match(image: eight)
XCTAssertEqual(x, 536)
#else
autoreleasepool(invoking: {
do {
let x = try img.match(image: eight)
XCTAssertEqual(x, 536)
}catch {
XCTFail("label loop: \(error)")
}
})
#endif
}
}catch {
XCTFail("label: \(error)")
}
}

class LabelImage {
let g: TF.Graph
let def: TF.GraphDef

public init(_ modelPath:String = "/tmp/testdata/tensorflow_inception_graph.pb") throws {
g = try TF.Graph()
guard let bytes = Data.Load(modelPath) else { throw TF.Panic.INVALID }
let def = try TF.GraphDef(serializedData: bytes)
try g.import(definition: def)
let inp = try g.searchOperation(forName: "input")
print(inp)
def = try TF.GraphDef(serializedData: bytes)
}

public func match(image: Data) throws -> Int {
let normalized = try constructAndExecuteGraphToNormalizeImage(imageBytes: image)
let possibilities = try executeInceptionGraph(image: normalized)
let g = try TF.Graph()
try g.import(definition: def)
let normalized = try constructAndExecuteGraphToNormalizeImage(g, imageBytes: image)
let possibilities = try executeInceptionGraph(g, image: normalized)
guard let m = possibilities.max(), let i = possibilities.index(of: m) else {
throw TF.Panic.INVALID
}//end guard
print("max", m)
return i
}

private func executeInceptionGraph(image: TF.Tensor) throws -> [Float] {
private func executeInceptionGraph(_ g: TF.Graph, image: TF.Tensor) throws -> [Float] {
let results = try g.runner().feed("input", tensor: image).fetch("output").run()
guard results.count > 0 else { throw TF.Panic.INVALID }
let result = results[0]
guard result.dimensionCount == 2 else { throw TF.Panic.INVALID }
let shape = result.dim
guard shape[0] == 1 else { throw TF.Panic.INVALID }
print(shape[1])
let res: [Float] = try result.asArray()
return res
}//end exec

public func constructAndExecuteGraphToNormalizeImage(imageBytes: Data) throws -> TF.Tensor {
public func constructAndExecuteGraphToNormalizeImage(_ g: TF.Graph, imageBytes: Data) throws -> TF.Tensor{
let H:Int32 = 224
let W:Int32 = 224
let mean:Float = 117
Expand Down Expand Up @@ -1057,7 +1065,6 @@ class PerfectTensorFlowTests: XCTestCase {
func testOpList() {
do {
let oplist = try TF.OperationList()
print(oplist.operations.count)
XCTAssertGreaterThan(oplist.operations.count, 0)
}catch {
XCTFail("OpList: \(error)")
Expand Down

0 comments on commit aa2bea8

Please sign in to comment.