-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
586 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,69 @@ | ||
<script> | ||
import "./app.css"; | ||
import Header from "./components/Header.svelte"; | ||
import "./app.css"; | ||
import Header from "./lib/Header.svelte"; | ||
import ImageSelector from "./lib/ImageSelector.svelte"; | ||
import { onDestroy, onMount } from "svelte"; | ||
import { loadEmbeddings, loadModels } from "./load"; | ||
import * as tf from "@tensorflow/tfjs"; | ||
import * as tfu from "./tfUtils"; | ||
class VectorQuantizer { | ||
constructor(embeddings) { | ||
this.embeddings = embeddings; | ||
this.numEmbed = this.embeddings.shape[1]; | ||
this.embedDim = this.embeddings.shape[0]; | ||
} | ||
/** | ||
* @param {tf.Tensor} x | ||
*/ | ||
predict(x) { | ||
const xShape = x.shape; | ||
const features = x.reshape([-1, this.embedDim]); | ||
// quantization step | ||
const idxs = tfu.tfDist(features, this.embeddings).argMin(1); | ||
const selectColumns = tf.oneHot(idxs, this.numEmbed); | ||
const quantized = tf.matMul( | ||
selectColumns, | ||
this.embeddings.transpose() | ||
); | ||
return quantized.reshape(xShape); | ||
} | ||
} | ||
const images = [1, 2, 3, 4, 5, 7].map((d) => `images/${d}.png`); | ||
let rawImages; | ||
let selectedImage = "images/1.png"; | ||
$: console.log(selectedImage); | ||
onMount(async () => { | ||
const embeddings = await loadEmbeddings(); | ||
tf.tidy(() => { | ||
const tensorEmbeddings = tf.tensor(embeddings); | ||
const vq = new VectorQuantizer(tensorEmbeddings); | ||
}); | ||
// tf.tidy(() => { | ||
// loadModels().then(([encoder, decoder]) => { | ||
// const input = tf.ones([1, 28, 28, 1]); | ||
// const out = encoder.predict(input); | ||
// const decInput = tf.ones([1, 7, 7, 16]); | ||
// const decOut = decoder.predict(decInput); | ||
// }); | ||
// }); | ||
}); | ||
</script> | ||
|
||
<Header /> | ||
<main> | ||
... | ||
<main class="p-5"> | ||
<div class="mb-2 flex gap-2 items-center"> | ||
<ImageSelector imageUrls={images} bind:selectedUrl={selectedImage} /> | ||
</div> | ||
</main> | ||
|
||
|
||
<style></style> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
<script> | ||
export let imageUrls = []; | ||
export let selectedUrl; | ||
const width = 40; | ||
</script> | ||
|
||
<div class="flex gap-2"> | ||
{#each imageUrls as src, i} | ||
<img | ||
{src} | ||
alt="begone yellow squigly lines" | ||
{width} | ||
class:lined={selectedUrl === src} | ||
on:click={() => { | ||
selectedUrl = src; | ||
}} | ||
/> | ||
{/each} | ||
<div | ||
class="s" | ||
style="width: {width}px; height: {width}px; background: black; opacity: {selectedUrl === | ||
'clear' | ||
? 1 | ||
: 0.4}; {selectedUrl === 'clear' | ||
? 'box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.5);' | ||
: ''};" | ||
on:click={() => (selectedUrl = "clear")} | ||
></div> | ||
</div> | ||
|
||
<style> | ||
.lined { | ||
opacity: 1; | ||
box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.5); | ||
} | ||
.s { | ||
cursor: pointer; | ||
} | ||
img { | ||
cursor: pointer; | ||
opacity: 0.4; | ||
} | ||
</style> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import * as tf from "@tensorflow/tfjs"; | ||
|
||
const filepath = "tfjs"; | ||
|
||
export async function loadModels() { | ||
const encoder = await tf.loadGraphModel(`${filepath}/encoder/model.json`); | ||
const decoder = await tf.loadGraphModel(`${filepath}/decoder/model.json`); | ||
return [encoder, decoder]; | ||
} | ||
|
||
/** | ||
* @param {string} txt | ||
*/ | ||
function parseNumpyTxt(txt) { | ||
const rows = txt.split("\n"); | ||
let result = []; | ||
for (let i = 0; i < rows.length - 1; ++i) { | ||
const r = rows[i].split(" "); | ||
const floats = r.map((d) => Number(d)); | ||
result.push(floats); | ||
} | ||
return result; | ||
} | ||
|
||
export async function loadEmbeddings(name = "embeddings_dim16_num32.txt") { | ||
const out = await (await fetch(`${filepath}/${name}`)).text(); | ||
const parsed = parseNumpyTxt(out); | ||
return parsed; | ||
} | ||
|
||
export function loadImage(url) { | ||
const img = new Image(); | ||
return new Promise((res, rej) => { | ||
img.src = url; | ||
img.onload = () => res(img); | ||
img.onerror = rej; | ||
}); | ||
} | ||
|
||
export async function loadImageFull(url) { | ||
const img = await loadImage(url); | ||
const canvas = document.createElement("canvas"); | ||
const ctx = canvas.getContext("2d"); | ||
ctx.drawImage(img, 0, 0); | ||
const d = ctx.getImageData(0, 0, img.width, img.height).data; | ||
img.remove(); | ||
canvas.remove(); | ||
return d; | ||
} | ||
|
||
function toGrey(d) { | ||
const result = new Uint8ClampedArray(d.length / 4); | ||
for (let i = 0, j = 0; i < d.length; i += 4, j++) { | ||
result[j] = d[i]; | ||
} | ||
return result; | ||
} | ||
|
||
export async function fetchAllImages(urls) { | ||
let result = {}; | ||
for (let i = 0; i < urls.length; i++) { | ||
const url = urls[i]; | ||
const d = await loadImageFull(url); | ||
const g = toGrey(d); | ||
const f32 = new Float32Array(g.length).map((_, i) => g[i] / 255); | ||
result[url] = f32; | ||
} | ||
return result; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import * as tf from "@tensorflow/tfjs"; | ||
|
||
/** | ||
* @param {tf.Tensor} a | ||
* @param {tf.Tensor} b | ||
*/ | ||
export function tfDist(a, b) { | ||
const a_2 = tf.sum(tf.square(a), 1, true); | ||
const b_2 = tf.sum(tf.square(b), 0, true); | ||
const twoab = tf.mul(2, tf.matMul(a, b)); | ||
return a_2.add(b_2).sub(twoab); | ||
} | ||
|
||
/** | ||
* @param {number[][]} a | ||
* @param {number[][]} b | ||
* @returns {number[][]} | ||
*/ | ||
export function dist(a, b) { | ||
let og = []; | ||
for (let i = 0; i < a.length; ++i) { | ||
let res = []; | ||
for (let j = 0; j < b[0].length; ++j) { | ||
let summed = 0; | ||
for (let k = 0; k < a[0].length; ++k) { | ||
summed += (a[i][k] - b[k][j]) ** 2; | ||
} | ||
res.push(summed); | ||
} | ||
og.push(res); | ||
} | ||
return og; | ||
} | ||
|
||
/** | ||
* @param {number[][]} d | ||
* @returns {number[]} | ||
*/ | ||
export function argmin(d) { | ||
let res = Array(d.length); | ||
for (let i = 0; i < d.length; ++i) { | ||
let min = Infinity; | ||
let minIndex = -1; | ||
for (let j = 0; j < d[0].length; ++j) { | ||
if (d[i][j] < min) { | ||
min = d[i][j]; | ||
minIndex = j; | ||
} | ||
} | ||
res[i] = minIndex; | ||
} | ||
return res; | ||
} |