Skip to content

Commit

Permalink
Prepare Graph.Function Class for incoming v1.3.1
Browse files Browse the repository at this point in the history
Also adding version automatic verification in dynamic loading.
  • Loading branch information
RockfordWei committed Sep 18, 2017
1 parent 32574eb commit c82e0c1
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 16 deletions.
45 changes: 32 additions & 13 deletions Sources/PerfectTensorFlow/APILoader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -894,18 +894,38 @@ public class TFLib {
throw Panic.DLL(reason: String(cString: dlerror()))
}//end lib
libDLL = lib
// GraphAddFunction = try LoadFunction(lib, "TF_GraphAddFunction")
// GraphToFunction = try LoadFunction(lib, "TF_GraphToFunction")
// FunctionToFunctionDef = try LoadFunction(lib, "TF_FunctionToFunctionDef")
// DeleteFunction = try LoadFunction(lib, "TF_DeleteFunction")

SessionListDevices = try LoadFunction(lib, "TF_SessionListDevices")
DeleteDeviceList = try LoadFunction(lib, "TF_DeleteDeviceList")
DeviceListCount = try LoadFunction(lib, "TF_DeviceListCount")
DeviceListName = try LoadFunction(lib, "TF_DeviceListName")
DeviceListType = try LoadFunction(lib, "TF_DeviceListType")
DeviceListMemoryBytes = try LoadFunction(lib, "TF_DeviceListMemoryBytes")
AddGradients = try LoadFunction(lib, "TF_AddGradients")
Version = try LoadFunction(lib, "TF_Version")

guard let v = Version() else {
throw Panic.DLL(reason: "Unresoved version info")
}

let ver = String(cString: v)

guard ver >= "1.1.0" else {
throw Panic.DLL(reason: "Version \(ver) is obsolete and out of support.")
}

if ver > "1.3.0" {
GraphAddFunction = try LoadFunction(lib, "TF_GraphAddFunction")
GraphToFunction = try LoadFunction(lib, "TF_GraphToFunction")
FunctionToFunctionDef = try LoadFunction(lib, "TF_FunctionToFunctionDef")
DeleteFunction = try LoadFunction(lib, "TF_DeleteFunction")
}

if ver >= "1.2.1" {
SessionListDevices = try LoadFunction(lib, "TF_SessionListDevices")
DeleteDeviceList = try LoadFunction(lib, "TF_DeleteDeviceList")
DeviceListCount = try LoadFunction(lib, "TF_DeviceListCount")
DeviceListName = try LoadFunction(lib, "TF_DeviceListName")
DeviceListType = try LoadFunction(lib, "TF_DeviceListType")
DeviceListMemoryBytes = try LoadFunction(lib, "TF_DeviceListMemoryBytes")
}

if ver >= "1.2.0" {
AddGradients = try LoadFunction(lib, "TF_AddGradients")
}

SetAttrValueProto = try LoadFunction(lib, "TF_SetAttrValueProto")
GetAllOpList = try LoadFunction(lib, "TF_GetAllOpList")
DeleteLibraryHandle = try LoadFunction(lib, "TF_DeleteLibraryHandle")
Expand Down Expand Up @@ -1021,7 +1041,6 @@ public class TFLib {
DeleteStatus = try LoadFunction(lib, "TF_DeleteStatus")
NewStatus = try LoadFunction(lib, "TF_NewStatus")
DataTypeSize = try LoadFunction(lib, "TF_DataTypeSize")
Version = try LoadFunction(lib, "TF_Version")
}//end open

/// static library closing
Expand Down
122 changes: 119 additions & 3 deletions Sources/PerfectTensorFlow/PerfectTensorFlow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1685,50 +1685,120 @@ public class TensorFlow {
return try Runner(graph: self, sessionOptions: sessionOptions, runOptions: runOptions, exportDir: exportDir, tags: tags, metaGraphDef: metaGraphDef)
}//end session

/// Generate a const operation
/// - parameters:
/// - tensor: Tensor to build in this operation
/// - type: String, default is "Const"
/// - name: String, default is "const". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func const(tensor: Tensor, `type`: String = "Const", name: String = "const") throws -> Operation {
guard let tp = tensor.type else { throw Panic.INVALID }
return try self.opBuilder(name: name, type: type)
.set(attributes: ["value": tensor, "dtype": tp]).build()
}//end const

/// Generate a placeholder tensor
/// - parameters:
/// - type: String, default is "Placeholder"
/// - name: String, default is "feed". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func placeholder(`type`: String = "Placeholder", name: String = "feed") throws -> Operation {
return try self.opBuilder(name: name, type: type)
.set(attributes: ["dtype": DataType.dtInt32]).build()
}//end Placeholder

/// Generate an integer scalar tensor
/// - parameters:
/// - v: Integer
/// - name: String, default is "scalar". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func scalar(_ v: Int, name: String = "scalar") throws -> Operation {
let x = try Tensor.Scalar(Int32(v))
return try self.const(tensor: x, name: name)
}//end ScalarConst

/// Generate an float type scalar tensor
/// - parameters:
/// - v: Float
/// - name: String, default is "scalar". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func scalar(_ v: Float, name: String = "scalar") throws -> Operation {
let x = try Tensor.Scalar(Float32(v))
return try self.const(tensor: x, name: name)
}//end ScalarConst

/// Add two inputs
/// - parameters:
/// - left: Output, left input to add
/// - right: Output, right input to add
/// - type: String, type of the operation, default is "AddN"
/// - name: String, default is "add". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func add(left: Output, right: Output, `type`: String = "AddN", name: String = "add") throws -> Operation {
return try self.opBuilder(name: name, type: type).add(inputs: [left, right]).build()
}//end Add

/// Add two operations
/// - parameters:
/// - left: Operation, left operation to add
/// - right: Operation, right operation to add
/// - name: String, default is "add". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func add(left: Operation, right: Operation, name: String = "add") throws -> Operation {
return try self.add(left: left.asOutput(0), right: right.asOutput(0), name: name)
}//end Add

/// Generate a negative operation from the current one.
/// - parameters:
/// - n: Operation, the original operation.
/// - type: String, type of the operation, default is "Neg"
/// - name: String, default is "neg". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func neg(_ n: Operation, `type`: String = "Neg", name: String = "neg") throws -> Operation {
return try self.opBuilder(name: name, type: type)
.add(input: n.asOutput(0)).build()
}//end neg

/// Compare two inputs and test if the left operand is less than the right one.
/// - parameters:
/// - left: Output, left input to compare
/// - right: Output, right input to compare
/// - type: String, type of the operation, default is "Less"
/// - name: String, default is "less_than". **CAUTION:** Naming space is not available up to 1.3.0
/// - returns: Operation
/// - throws: Panic
public func lessThan(left: Output, right: Output, `type`: String = "Less", name: String = "less_than") throws -> Operation {
return try self.opBuilder(name: name, type: type).add(input: left).add(input: right).build()
}//end LessThan

/// Setup a 2 x 2 matrix
/// - parameters:
/// - values: [Float], the matrix in row major order.
/// - name: String, name of the matrix to create
/// - returns: Operation
/// - throws: Panic
public func floatConst2x2(values: [Float], name: String) throws -> Operation {
let tensor = try Tensor.Array(dimensions: [2,2], value: values)
return try self.opBuilder(name: name, type: "Const")
.set(attributes: ["value": tensor, "dtype": DataType.dtFloat]).build()
}//end FloatConst2x2

/// Matrix Multipliction
/// - parameters:
/// - l: Operation, left matrix to multiply
/// - r: Operation, right matrix to multiply
/// - name: String, name of the matrix to create
/// - transposeA: Bool, should transpose left matrix before multiplication
/// - transposeB: Bool, should transpose right matrix before multiplication
/// - returns: Operation
/// - throws: Panic
public func matMul(l: Operation, r: Operation, name: String, transposeA: Bool = false, transposeB: Bool = false ) throws -> Operation {
var a: [String: Any] = [:]
if transposeA {
Expand Down Expand Up @@ -1798,26 +1868,72 @@ public class TensorFlow {
.add(input: in1).add(input: in2).build().output(index)
}

/// Create a function from a graph
/// - parameters:
/// - name: String, the name of the new TF_Function. Should match the operation name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct from other operation names (at least those registered in graphs where this function will be used).
/// - operations: [Operation], Array of operations to become the body of the function or null
/// - inputs: [Output], array of TF_Outputs that specify the inputs to the function
/// - outputs: [Output], array of TF_Outputs that specify the outputs of the function.
/// - outputNames: [String], The names of the function's outputs. Must either have the same length as `outputs` or be null. In the former case, the names should match the regular expression for ArgDef names - "[a-z][a-z0-9_]*". In the latter case, names for outputs will be generated automatically.
public func toFunction(_ name: String, operations: [Operation], inputs: [Output], outputs: [Output], outputNames: [String], options: OpaquePointer? = nil) throws -> Function? {
return nil
guard outputs.count == outputNames.count else {
throw Panic.FAULT(reason: "Output array elements are mismatched with names")
}
let status = try Status()
let opera:UnsafePointer<OpaquePointer?>? = operations.map { $0.operation }
.withUnsafeBufferPointer { $0.baseAddress }
let pInputs = inputs.withUnsafeBufferPointer { $0.baseAddress }
let pOutpus = outputs.withUnsafeBufferPointer { $0.baseAddress }
let pOutputNames:UnsafePointer<UnsafePointer<CChar>?>? = outputNames
.map { $0.withCString { p -> UnsafePointer<CChar> in return p } }
.withUnsafeBufferPointer { $0.baseAddress }
guard let fun = TFLib.GraphToFunction(graph, name, Int32(operations.count > 0 ? operations.count: -1), operations.count > 0 ? opera : nil, Int32(inputs.count), pInputs, Int32(outputs.count), pOutpus, pOutputNames, options, status.status),
let code = status.code, code == .OK else {
throw Panic.FAULT(reason: status.message)
}//end guard
return Function(self, reference: fun)
}

/// Function is a grouping of operations with defined inputs and outputs.
/// Once created and added to graphs, functions can be invoked by creating an
/// operation whose operation type matches the function name.
public class Function {
let g: Graph
let ref: OpaquePointer

/// constructor. DO **NOT** CALL IT DIRECTLY. Call `Graph.toFunction()` to generate function instead.
public init(_ graph: Graph, reference: OpaquePointer) {
g = graph
ref = reference
}

public func add () {
TFLib.GraphAddFunction(g.graph, ref, nil)
/// Add `function` to graph `g`. Once `function` is added to `g`,
/// it can be called by creating an operation using the function's name.
public func add () throws {
let status = try Status()
TFLib.GraphAddFunction(g.graph, ref, status.status)
guard let code = status.code, code == .OK else {
throw Panic.FAULT(reason: status.message)
}//end guard
}

/// delete function
public func delete() {
TFLib.DeleteFunction(ref)
}

/// get protocol buffer of the current function
public var buffer: Buffer? {
guard let buf = try? Buffer(), let status = try? Status() else {
return nil
}
TFLib.FunctionToFunctionDef(ref, buf.buffer, status.status)
if let code = status.code, code == .OK {
return buf
} else {
return nil
}
}
}

}//end graph
Expand Down
1 change: 1 addition & 0 deletions install.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/bin/bash
OSABR=$(echo $(uname)|tr '[:upper:]' '[:lower:]')
DWN=/tmp/libtensorflow.tgz
VERSION=`cat VERSION`
Expand Down
1 change: 1 addition & 0 deletions test.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/bin/bash
# This script is for CI Server
VERSION=`cat VERSION`
echo 'Clean Temp Files'
Expand Down

0 comments on commit c82e0c1

Please sign in to comment.