Skip to content

Commit

Permalink
image selection
Browse files Browse the repository at this point in the history
  • Loading branch information
xnought committed Sep 25, 2024
1 parent 97940b7 commit 199ddbf
Show file tree
Hide file tree
Showing 13 changed files with 586 additions and 28 deletions.
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@
"svelte": "^4.2.18",
"tailwindcss": "^3.4.9",
"vite": "^5.4.1"
},
"dependencies": {
"@tensorflow/tfjs": "^4.21.0"
}
}
359 changes: 347 additions & 12 deletions pnpm-lock.yaml

Large diffs are not rendered by default.

Binary file added public/images/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added public/images/7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
67 changes: 62 additions & 5 deletions src/App.svelte
Original file line number Diff line number Diff line change
@@ -1,12 +1,69 @@
<script>
import "./app.css";
import Header from "./components/Header.svelte";
import "./app.css";
import Header from "./lib/Header.svelte";
import ImageSelector from "./lib/ImageSelector.svelte";
import { onDestroy, onMount } from "svelte";
import { loadEmbeddings, loadModels } from "./load";
import * as tf from "@tensorflow/tfjs";
import * as tfu from "./tfUtils";
class VectorQuantizer {
constructor(embeddings) {
this.embeddings = embeddings;
this.numEmbed = this.embeddings.shape[1];
this.embedDim = this.embeddings.shape[0];
}
/**
* @param {tf.Tensor} x
*/
predict(x) {
const xShape = x.shape;
const features = x.reshape([-1, this.embedDim]);
// quantization step
const idxs = tfu.tfDist(features, this.embeddings).argMin(1);
const selectColumns = tf.oneHot(idxs, this.numEmbed);
const quantized = tf.matMul(
selectColumns,
this.embeddings.transpose()
);
return quantized.reshape(xShape);
}
}
const images = [1, 2, 3, 4, 5, 7].map((d) => `images/${d}.png`);
let rawImages;
let selectedImage = "images/1.png";
$: console.log(selectedImage);
onMount(async () => {
const embeddings = await loadEmbeddings();
tf.tidy(() => {
const tensorEmbeddings = tf.tensor(embeddings);
const vq = new VectorQuantizer(tensorEmbeddings);
});
// tf.tidy(() => {
// loadModels().then(([encoder, decoder]) => {
// const input = tf.ones([1, 28, 28, 1]);
// const out = encoder.predict(input);
// const decInput = tf.ones([1, 7, 7, 16]);
// const decOut = decoder.predict(decInput);
// });
// });
});
</script>

<Header />
<main>
...
<main class="p-5">
<div class="mb-2 flex gap-2 items-center">
<ImageSelector imageUrls={images} bind:selectedUrl={selectedImage} />
</div>
</main>


<style></style>
20 changes: 9 additions & 11 deletions src/components/Header.svelte → src/lib/Header.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@
import { Button } from "flowbite-svelte";
import { GithubSolid, FilePdfSolid } from "flowbite-svelte-icons";
const color = "alternative";
const color = "alternative";
</script>

<nav>
<div id="inner">
<div class="left">
<img src="logo.svg" alt="logo" />
<span id="title-desc">
Interact with a <b>V</b>ector <b>Q</b>uantized <b>V</b>ariational <b>A</b>uto<b>e</b>ncoder
(<b>VQ-VAE</b>) in your browser!
</span>
<!-- <span id="title-desc">
Interact with a <b>VQ-VAE</b> in your browser!
</span> -->
</div>
<div class="right">
<Button
size="xs"
href="https://github.com/xnought/vq-vae-explainer"
target="_blank"
{color}
outline
target="_blank"
{color}
outline
>Code
<GithubSolid size="md" class="ml-1" />
</Button>
<!--
<!--
<Button
size="xs"
href="https://arxiv.org/abs/2409.09011"
Expand All @@ -41,7 +40,7 @@

<style>
nav {
box-shadow: 0px 0px 3px 2px rgba(0,0,0,0.1);
box-shadow: 0px 0px 3px 2px rgba(0, 0, 0, 0.1);
}
#inner {
padding: 10px;
Expand Down Expand Up @@ -71,4 +70,3 @@
cursor: pointer;
} */
</style>

43 changes: 43 additions & 0 deletions src/lib/ImageSelector.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<script>
export let imageUrls = [];
export let selectedUrl;
const width = 40;
</script>

<div class="flex gap-2">
{#each imageUrls as src, i}
<img
{src}
alt="begone yellow squigly lines"
{width}
class:lined={selectedUrl === src}
on:click={() => {
selectedUrl = src;
}}
/>
{/each}
<div
class="s"
style="width: {width}px; height: {width}px; background: black; opacity: {selectedUrl ===
'clear'
? 1
: 0.4}; {selectedUrl === 'clear'
? 'box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.5);'
: ''};"
on:click={() => (selectedUrl = "clear")}
></div>
</div>

<style>
.lined {
opacity: 1;
box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.5);
}
.s {
cursor: pointer;
}
img {
cursor: pointer;
opacity: 0.4;
}
</style>
69 changes: 69 additions & 0 deletions src/load.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import * as tf from "@tensorflow/tfjs";

const filepath = "tfjs";

export async function loadModels() {
const encoder = await tf.loadGraphModel(`${filepath}/encoder/model.json`);
const decoder = await tf.loadGraphModel(`${filepath}/decoder/model.json`);
return [encoder, decoder];
}

/**
* @param {string} txt
*/
function parseNumpyTxt(txt) {
const rows = txt.split("\n");
let result = [];
for (let i = 0; i < rows.length - 1; ++i) {
const r = rows[i].split(" ");
const floats = r.map((d) => Number(d));
result.push(floats);
}
return result;
}

export async function loadEmbeddings(name = "embeddings_dim16_num32.txt") {
const out = await (await fetch(`${filepath}/${name}`)).text();
const parsed = parseNumpyTxt(out);
return parsed;
}

export function loadImage(url) {
const img = new Image();
return new Promise((res, rej) => {
img.src = url;
img.onload = () => res(img);
img.onerror = rej;
});
}

export async function loadImageFull(url) {
const img = await loadImage(url);
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
ctx.drawImage(img, 0, 0);
const d = ctx.getImageData(0, 0, img.width, img.height).data;
img.remove();
canvas.remove();
return d;
}

function toGrey(d) {
const result = new Uint8ClampedArray(d.length / 4);
for (let i = 0, j = 0; i < d.length; i += 4, j++) {
result[j] = d[i];
}
return result;
}

export async function fetchAllImages(urls) {
let result = {};
for (let i = 0; i < urls.length; i++) {
const url = urls[i];
const d = await loadImageFull(url);
const g = toGrey(d);
const f32 = new Float32Array(g.length).map((_, i) => g[i] / 255);
result[url] = f32;
}
return result;
}
53 changes: 53 additions & 0 deletions src/tfUtils.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import * as tf from "@tensorflow/tfjs";

/**
* @param {tf.Tensor} a
* @param {tf.Tensor} b
*/
export function tfDist(a, b) {
const a_2 = tf.sum(tf.square(a), 1, true);
const b_2 = tf.sum(tf.square(b), 0, true);
const twoab = tf.mul(2, tf.matMul(a, b));
return a_2.add(b_2).sub(twoab);
}

/**
* @param {number[][]} a
* @param {number[][]} b
* @returns {number[][]}
*/
export function dist(a, b) {
let og = [];
for (let i = 0; i < a.length; ++i) {
let res = [];
for (let j = 0; j < b[0].length; ++j) {
let summed = 0;
for (let k = 0; k < a[0].length; ++k) {
summed += (a[i][k] - b[k][j]) ** 2;
}
res.push(summed);
}
og.push(res);
}
return og;
}

/**
* @param {number[][]} d
* @returns {number[]}
*/
export function argmin(d) {
let res = Array(d.length);
for (let i = 0; i < d.length; ++i) {
let min = Infinity;
let minIndex = -1;
for (let j = 0; j < d[0].length; ++j) {
if (d[i][j] < min) {
min = d[i][j];
minIndex = j;
}
}
res[i] = minIndex;
}
return res;
}

0 comments on commit 199ddbf

Please sign in to comment.