Skip to content

Commit

Permalink
Update pytorch.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 8, 2024
1 parent 97a00fc commit f93e11e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 52 deletions.
103 changes: 57 additions & 46 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -1695,44 +1695,57 @@ python.Execution = class {
this.registerFunction('builtins.__import__', (name, globals, locals, fromlist, level) => {
return execution.__import__(name, globals, locals, fromlist, level);
});
this.registerFunction('builtins.bool', (value) => {
if (value) {
if (value.__bool__) {
return value.__bool__();
}
if (value.__len__) {
return value.__len__() > 0;
this.registerType('builtins.bool', class extends Boolean {
constructor(value) {
if (value && value.__bool__) {
value = value.__bool__();
} else if (value && value.__len__) {
value = value.__len__() > 0;
} else {
value = value ? true : false;
}
super(value);
}
return false;
});
this.registerFunction('builtins.int', (value) => {
if (value) {
if (value.__int__) {
return value.__int__();
}
if (Number.isInteger(value)) {
return value;
this.registerType('builtins.int', class extends Number {
constructor(value) {
if (value && value.__int__) {
value = value.__int__();
} else if (!Number.isInteger(value)) {
value = NaN;
}
super(value);
}
return NaN;
});
this.registerFunction('builtins.float', (value) => {
if (value) {
if (value.__float__) {
return value.__float__();
this.registerType('builtins.float', class extends Number {
constructor(value) {
if (value && value.__float__) {
value = value.__float__();
} else if (Number(value) !== value) {
value = NaN;
}
if (Number(value) === value) {
return value;
super(value);
}
});
this.registerType('builtins.long', class extends Number {
constructor(value) {
if (value && value.__int__) {
value = value.__int__();
} else if (!Number.isInteger(value)) {
value = NaN;
}
super(value);
}
return NaN;
});
this.registerFunction('builtins.str', (value) => {
if (value && value.__str__) {
return value.__str__();
this.registerType('builtins.str', class extends String {
constructor(value) {
if (value && value.__str__) {
value = value.__str__();
} else if (typeof value !== 'string') {
value = JSON.stringify(value);
}
super(value);
}
return JSON.stringify(value);
});
this.registerType('builtins.complex', class {
constructor(real, imaginary) {
Expand Down Expand Up @@ -1763,7 +1776,6 @@ python.Execution = class {
this.registerType('builtins.Exception', class extends builtins.BaseException {});
this.registerType('builtins.AttributeError', class extends builtins.Exception {});
this.registerType('builtins.SyntaxError', class extends builtins.Exception {});
this.registerFunction('builtins.long', this.builtins.int);
this.registerFunction('builtins.print', () => {});
this.registerFunction('builtins.unicode');
builtins.Ellipsis = new builtins.ellipsis();
Expand Down Expand Up @@ -3613,8 +3625,7 @@ python.Execution = class {
for (const name of ['__builtin__', 'types']) {
const module = self.register(name);
for (const [name, obj] of Object.entries(module)) {
if (obj.__module__ === 'builtins' &&
obj.__class__ === builtins.type) {
if (obj.__module__ === 'builtins' && obj.__class__ === builtins.type) {
_dill._reverse_typemap.set(name, obj);
}
}
Expand Down Expand Up @@ -4971,8 +4982,8 @@ python.Execution = class {
return tensor;
});
this.registerFunction('torch.add', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
return left * right;
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left + right;
}
if (Array.isArray(left) && Array.isArray(right)) {
return left.concat(right);
Expand Down Expand Up @@ -5039,7 +5050,7 @@ python.Execution = class {
if (typeof left === 'string' && typeof right === 'string') {
return left === right;
}
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (isNaN(left) && isNaN(right)) {
return true;
}
Expand Down Expand Up @@ -5076,7 +5087,7 @@ python.Execution = class {
}).join('');
});
this.registerFunction('torch.gt', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (!isNaN(left) && !isNaN(right)) {
return left > right;
}
Expand All @@ -5087,7 +5098,7 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.gt' expression type.");
});
this.registerFunction('torch.ge', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (!isNaN(left) && !isNaN(right)) {
return left > right;
}
Expand Down Expand Up @@ -5145,7 +5156,7 @@ python.Execution = class {
return NaN;
});
this.registerFunction('torch.le', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (isNaN(left) || isNaN(right)) {
return false;
}
Expand Down Expand Up @@ -5352,25 +5363,25 @@ python.Execution = class {
});
this.registerFunction('torch.log10');
this.registerFunction('torch.lt', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left < right;
}
throw new python.Error("Unsupported 'torch.lt' expression type.");
});
this.registerFunction('torch.mul', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left * right;
}
if (isNaN(left) || isNaN(right)) {
return NaN;
}
if (Array.isArray(left) && left.every((value) => typeof value === 'number') && typeof right === 'number') {
if (Array.isArray(left) && left.every((value) => typeof value === 'number' || value instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left.map((value) => value * right);
}
throw new python.Error("Unsupported 'torch.mul' expression type.");
});
this.registerFunction('torch.div', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left / right;
}
if (isNaN(left) || isNaN(right)) {
Expand All @@ -5379,7 +5390,7 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.div' expression type.");
});
this.registerFunction('torch.round', (value) => {
if (typeof value === 'number') {
if (typeof value === 'number' || value instanceof Number) {
return Math.round(value);
}
if (isNaN(value)) {
Expand All @@ -5388,7 +5399,7 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.round' expression type.");
});
this.registerFunction('torch.remainder', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left % right;
}
if (isNaN(left) || isNaN(right)) {
Expand All @@ -5400,7 +5411,7 @@ python.Execution = class {
if (typeof left === 'boolean' && typeof right === 'boolean') {
return left !== right;
}
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
if (isNaN(left) || isNaN(right)) {
return false;
}
Expand All @@ -5424,7 +5435,7 @@ python.Execution = class {
throw new python.Error("Unsupported 'torch.neg' expression type.");
});
this.registerFunction('torch.pow', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return Math.pow(left, right);
}
throw new python.Error("Unsupported 'torch.pow' expression type.");
Expand Down Expand Up @@ -5474,7 +5485,7 @@ python.Execution = class {
return l.slice(start, end);
});
this.registerFunction('torch.sub', (left, right) => {
if (typeof left === 'number' && typeof right === 'number') {
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
return left - right;
}
throw new python.Error("Unsupported 'torch.sub' expression type.");
Expand Down Expand Up @@ -6909,7 +6920,7 @@ python.Execution = class {
const func = name ? callTarget[name] : callTarget;
if (func.__class__ === this._builtins.type) {
if (func.prototype && func.prototype.__class__ === func) {
return Reflect.construct(func, args);
return Reflect.construct(func, callArguments);
}
const obj = Object.create(func);
obj.__class__ = func;
Expand Down
14 changes: 8 additions & 6 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3419,40 +3419,42 @@ pytorch.Utility = class {
case 'Tensor[]':
return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null);
case 'Scalar':
return (obj !== null && obj !== Object(obj)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0);
return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0);
case 'boolean':
return obj === true || obj === false;
case 'string':
return obj === null || typeof obj === 'string';
case 'SymInt':
case 'int64':
return Number.isInteger(obj) || typeof obj === 'bigint' || (typeof obj === 'number' && isNaN(obj));
return Number.isInteger(obj) || typeof obj === 'bigint' ||
(typeof obj === 'number' && isNaN(obj)) || (obj instanceof Number);
case 'SymInt[]':
case 'SymInt[2]':
case 'SymInt[3]':
case 'SymInt[4]':
case 'SymInt[5]':
case 'SymInt[6]':
return Array.isArray(obj) && obj.every((item) => pytorch.Utility.isType(item, 'SymInt') || item === undefined || (item.__class__ === 'number' && isNaN(item)));
case 'int64[]':
case 'int64[2]':
case 'int64[3]':
return Array.isArray(obj) && obj.every((item) => Number.isInteger(item) || (typeof item === 'number' && isNaN(item)) || item === undefined);
return Array.isArray(obj) && obj.every((item) => pytorch.Utility.isType(item, 'int64') || item === undefined || (item.__class__ === 'number' && isNaN(item)));
case 'int64[1]':
case 'SymInt[1]':
return pytorch.Utility.isType(obj, 'int64') || pytorch.Utility.isType(obj, 'int64[]');
case 'float32':
case 'float64':
return obj !== null && obj !== Object(obj);
return obj !== null && (typeof obj === 'number' || obj instanceof Number);
case 'float32[]':
return Array.isArray(obj) && obj.every((item) => typeof item === 'number' && !isNaN(item));
return Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item));
case 'string[][]':
return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
case 'Layout':
case 'ScalarType':
case 'MemoryFormat':
return Number.isInteger(obj) || obj === null;
case 'Dimname':
return obj === null || typeof obj === 'string';
return obj === null || (typeof obj === 'string' || obj instanceof String);
case 'Dimname[]':
return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string');
case 'Device':
Expand Down

0 comments on commit f93e11e

Please sign in to comment.