Skip to content

Commit

Permalink
Update backend test (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 30, 2024
1 parent 5460f33 commit 9477659
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 64 deletions.
18 changes: 12 additions & 6 deletions source/browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ host.BrowserHost = class {
const [url] = this._meta.file;
if (this._view.accept(url)) {
const identifier = Array.isArray(this._meta.identifier) && this._meta.identifier.length === 1 ? this._meta.identifier[0] : null;
const status = await this._openModel(this._url(url), identifier || null);
const name = this._meta.name || null;
const status = await this._openModel(this._url(url), identifier || null, name);
if (status === '') {
return;
}
Expand Down Expand Up @@ -413,7 +414,7 @@ host.BrowserHost = class {
return `${location.protocol}//${location.host}${pathname}${file}`;
}

async _openModel(url, identifier) {
async _openModel(url, identifier, name) {
url = url.startsWith('data:') ? url : `${url + ((/\?/).test(url) ? '&' : '?')}cb=${(new Date()).getTime()}`;
this._view.show('welcome spinner');
let context = null;
Expand All @@ -430,7 +431,7 @@ host.BrowserHost = class {
stream = await this._request(url, null, null, progress);
}
}
context = new host.BrowserHost.Context(this, url, identifier, stream);
context = new host.BrowserHost.Context(this, url, identifier, name, stream);
this._telemetry.set('session_engaged', 1);
} catch (error) {
await this._view.error(error, 'Model load request failed.');
Expand Down Expand Up @@ -474,7 +475,7 @@ host.BrowserHost = class {
const encoder = new TextEncoder();
const buffer = encoder.encode(file.content);
const stream = new base.BinaryStream(buffer);
const context = new host.BrowserHost.Context(this, '', identifier, stream);
const context = new host.BrowserHost.Context(this, '', identifier, null, stream);
await this._openContext(context);
} catch (error) {
await this._view.error(error, 'Error while loading Gist.');
Expand All @@ -487,7 +488,7 @@ host.BrowserHost = class {
try {
const model = await this._view.open(context);
if (model) {
this.document.title = context.identifier;
this.document.title = context.name || context.identifier;
return '';
}
this.document.title = '';
Expand Down Expand Up @@ -787,8 +788,9 @@ host.BrowserHost.FileStream = class {

host.BrowserHost.Context = class {

constructor(host, url, identifier, stream) {
constructor(host, url, identifier, name, stream) {
this._host = host;
this._name = name;
this._stream = stream;
if (identifier) {
this._identifier = identifier;
Expand All @@ -807,6 +809,10 @@ host.BrowserHost.Context = class {
return this._identifier;
}

get name() {
return this._name;
}

get stream() {
return this._stream;
}
Expand Down
40 changes: 20 additions & 20 deletions source/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ def _metadata_props(self, metadata_props): # pylint: disable=missing-function-do
class _Graph:
def __init__(self, graph, metadata):
self.metadata = metadata
self.value = graph
self.arguments_index = {}
self.arguments = []
self.graph = graph
self.values_index = {}
self.values = []

def _tensor(self, tensor): # pylint: disable=unused-argument
return {}

def argument(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring
if not name in self.arguments_index:
argument = _Argument(name, tensor_type, initializer)
self.arguments_index[name] = len(self.arguments)
self.arguments.append(argument)
index = self.arguments_index[name]
def value(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring
if not name in self.values_index:
argument = _Value(name, tensor_type, initializer)
self.values_index[name] = len(self.values)
self.values.append(argument)
index = self.values_index[name]
# argument.set_initializer(initializer)
return index

Expand Down Expand Up @@ -138,17 +138,17 @@ def attribute(self, _, op_type): # pylint: disable=missing-function-docstring,to
return json_attribute

def to_json(self): # pylint: disable=missing-function-docstring
graph = self.value
graph = self.graph
json_graph = {
'nodes': [],
'inputs': [],
'outputs': [],
'arguments': []
'values': []
}
for value_info in graph.value_info:
self.argument(value_info.name)
self.value(value_info.name)
for initializer in graph.initializer:
self.argument(initializer.name, None, initializer)
self.value(initializer.name, None, initializer)
for node in graph.node:
op_type = node.op_type
json_node = {}
Expand All @@ -164,24 +164,24 @@ def to_json(self): # pylint: disable=missing-function-docstring
for value in node.input:
json_node['inputs'].append({
'name': 'X',
'arguments': [ self.argument(value) ]
'value': [ self.value(value) ]
})
json_node['outputs'] = []
for value in node.output:
json_node['outputs'].append({
'name': 'X',
'arguments': [ self.argument(value) ]
'value': [ self.value(value) ]
})
json_node['attributes'] = []
for _ in node.attribute:
json_attribute = self.attribute(_, op_type)
json_node['attributes'].append(json_attribute)
json_graph['nodes'].append(json_node)
for _ in self.arguments:
json_graph['arguments'].append(_.to_json())
for _ in self.values:
json_graph['values'].append(_.to_json())
return json_graph

class _Argument: # pylint: disable=too-few-public-methods
class _Value: # pylint: disable=too-few-public-methods
def __init__(self, name, tensor_type=None, initializer=None):
self.name = name
self.type = tensor_type
Expand All @@ -190,8 +190,8 @@ def __init__(self, name, tensor_type=None, initializer=None):
def to_json(self): # pylint: disable=missing-function-docstring
target = {}
target['name'] = self.name
if self.initializer:
target['initializer'] = {}
# if self.initializer:
# target['initializer'] = {}
return target

class _Metadata: # pylint: disable=too-few-public-methods
Expand Down
44 changes: 23 additions & 21 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals,
import torch # pylint: disable=import-outside-toplevel,import-error
graph = self.value
json_graph = {
'arguments': [],
'values': [],
'nodes': [],
'inputs': [],
'outputs': []
Expand All @@ -73,59 +73,61 @@ def constant_value(node):
selector = node.kindOf('value')
return getattr(node, selector)('value')
return None
arguments_map = {}
values_index = {}
def argument(value):
if not value in arguments_map:
json_argument = {}
json_argument['name'] = str(value.unique())
if not value in values_index:
json_value = {}
json_value['name'] = str(value.unique())
node = value.node()
if node.kind() == "prim::GetAttr":
tensor, name = self._getattr(node)
if tensor is not None and len(name) > 0 and \
isinstance(tensor, torch.Tensor):
json_argument['name'] = name
json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
json_argument['type'] = {
tensor_type = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
json_value['name'] = name
json_value['type'] = tensor_type
json_value['initializer'] = { 'type': tensor_type }
elif node.kind() == "prim::Constant":
tensor = constant_value(node)
if tensor and isinstance(tensor, torch.Tensor):
json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
json_argument['type'] = {
tensor_type = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
json_value['type'] = tensor_type
json_value['initializer'] = { 'type': tensor_type }
elif value.isCompleteTensor():
json_tensor_shape = {
'dimensions': value.type().sizes()
}
json_argument['type'] = {
json_value['type'] = {
'dataType': data_type_map[value.type().dtype()],
'shape': json_tensor_shape
}
arguments = json_graph['arguments']
arguments_map[value] = len(arguments)
arguments.append(json_argument)
return arguments_map[value]
values = json_graph['values']
values_index[value] = len(values)
values.append(json_value)
return values_index[value]

for value in graph.inputs():
if len(value.uses()) != 0 and value.type().kind() != 'ClassType':
json_graph['inputs'].append({
'name': value.debugName(),
'arguments': [ argument(value) ]
'value': [ argument(value) ]
})
for value in graph.outputs():
json_graph['outputs'].append({
'name': value.debugName(),
'arguments': [ argument(value) ]
'value': [ argument(value) ]
})
constants = {}
for node in graph.nodes():
Expand Down Expand Up @@ -163,7 +165,7 @@ def create_node(node):
if torch.is_tensor(value):
json_node['inputs'].append({
'name': name,
'arguments': []
'value': []
})
else:
json_node['attributes'].append(json_attribute)
Expand All @@ -177,7 +179,7 @@ def create_node(node):
if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
json_node['inputs'].append({
'name': parameter_name,
'arguments': [ argument(value) ]
'value': [ argument(value) ]
})
else:
json_attribute = {
Expand All @@ -203,15 +205,15 @@ def create_node(node):
continue
json_node['inputs'].append({
'name': parameter_name,
'arguments': [ argument(value) ]
'value': [ argument(value) ]
})

for i, value in enumerate(node.outputs()):
parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
name = parameter['name'] if parameter and 'name' in parameter else 'output'
json_node['outputs'].append({
'name': name,
'arguments': [ argument(value) ]
'value': [ argument(value) ]
})

for node in graph.nodes():
Expand Down
32 changes: 18 additions & 14 deletions source/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,25 @@ message.Graph = class {
this.inputs = [];
this.outputs = [];
this.nodes = [];
const args = data.arguments ? data.arguments.map((argument) => new message.Value(argument)) : [];
for (const parameter of data.inputs || []) {
parameter.arguments = parameter.arguments.map((index) => args[index]).filter((argument) => !argument.initializer);
if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) {
this.inputs.push(new message.Argument(parameter));
const values = data.values ? data.values.map((value) => new message.Value(value)) : [];
for (const argument of data.inputs || []) {
argument.value = argument.value.map((index) => values[index]).filter((argument) => !argument.initializer);
if (argument.value.filter((argument) => !argument.initializer).length > 0) {
this.inputs.push(new message.Argument(argument));
}
}
for (const parameter of data.outputs || []) {
parameter.arguments = parameter.arguments.map((index) => args[index]);
if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) {
this.outputs.push(new message.Argument(parameter));
for (const argument of data.outputs || []) {
argument.value = argument.value.map((index) => values[index]);
if (argument.value.filter((argument) => !argument.initializer).length > 0) {
this.outputs.push(new message.Argument(argument));
}
}
for (const node of data.nodes || []) {
for (const parameter of node.inputs || []) {
parameter.arguments = parameter.arguments.map((index) => args[index]);
for (const argument of node.inputs || []) {
argument.value = argument.value.map((index) => values[index]);
}
for (const parameter of node.outputs || []) {
parameter.arguments = parameter.arguments.map((index) => args[index]);
for (const argument of node.outputs || []) {
argument.value = argument.value.map((index) => values[index]);
}
this.nodes.push(new message.Node(node));
}
Expand All @@ -76,7 +76,7 @@ message.Argument = class {

constructor(data) {
this.name = data.name || '';
this.value = (data.arguments || []);
this.value = data.value || [];
this.type = data.type || '';
}
};
Expand Down Expand Up @@ -129,6 +129,10 @@ message.TensorShape = class {
};

message.Tensor = class {

constructor(data) {
this.type = new message.TensorType(data.type);
}
};

message.Error = class extends Error {
Expand Down
10 changes: 7 additions & 3 deletions source/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ class _ContentProvider: # pylint: disable=too-few-public-methods
base_dir = ''
base = ''
identifier = ''
def __init__(self, data, path, file):
def __init__(self, data, path, file, name):
self.data = data if data else bytearray()
self.identifier = os.path.basename(file) if file else ''
self.name = name
if path:
self.dir = os.path.dirname(path) if os.path.dirname(path) else '.'
self.base = os.path.basename(path)
Expand Down Expand Up @@ -95,6 +96,9 @@ def do_GET(self): # pylint: disable=invalid-name
base = self.content.base
if base:
meta.append('<meta name="file" content="/data/' + base + '">')
name = self.content.name
if name:
meta.append('<meta name="name" content="' + name + '">')
identifier = self.content.identifier
if identifier:
meta.append('<meta name="identifier" content="' + identifier + '">')
Expand Down Expand Up @@ -281,14 +285,14 @@ def serve(file, data, address=None, browse=False, verbosity=1):
if not data and file and not os.path.exists(file):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)

content = _ContentProvider(data, file, file)
content = _ContentProvider(data, file, file, file)

if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
_log(verbosity > 1, 'Experimental\n')
model = _open(data)
if model:
text = json.dumps(model.to_json(), indent=4, ensure_ascii=False)
content = _ContentProvider(text.encode('utf-8'), 'model.netron', file)
content = _ContentProvider(text.encode('utf-8'), 'model.netron', None, file)

address = _make_address(address)
if isinstance(address[1], int) and address[1] != 0:
Expand Down

0 comments on commit 9477659

Please sign in to comment.