Skip to content

Commit

Permalink
Dirty commit!
Browse files Browse the repository at this point in the history
  • Loading branch information
jlegrand62 committed Jul 2, 2024
1 parent bc5783d commit e243600
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 66 deletions.
53 changes: 35 additions & 18 deletions src/plantdb/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,37 @@ def _try_has_file(task, file):

return scan_info

def get_file_uri(scan, fileset, file):
    """Build the URI locating a file within its `scan/fileset/file` hierarchy.

    Parameters
    ----------
    scan : plantdb.fsdb.Scan or str
        The scan dataset, given as a ``Scan`` instance or directly as its name.
    fileset : plantdb.fsdb.Fileset or str
        The fileset, given as a ``Fileset`` instance or directly as its name.
    file : plantdb.fsdb.File or str
        The file, given as a ``File`` instance or directly as its name.

    Returns
    -------
    str
        The URI, formatted as ``/files/<scan>/<fileset>/<file>``.
    """
    from plantdb.fsdb import File
    from plantdb.fsdb import Fileset
    from plantdb.fsdb import Scan

    # Database instances are reduced to their identifier, anything else is used verbatim.
    if isinstance(scan, Scan):
        scan = scan.id
    if isinstance(fileset, Fileset):
        fileset = fileset.id
    if isinstance(file, File):
        # A ``File`` instance contributes the name of its path.
        file = file.path().name
    return f"/files/{scan}/{fileset}/{file}"

# Mapping from a task name to the key under which the URI of the output file of that task
# is stored in the ``"filesUri"`` dictionary of the scan data.
# NOTE(review): `get_scan_data` tests ``scan_data[f"has{task}"]`` for each key of this mapping,
# whereas the code it replaces tested ``"hasMesh"`` and ``"hasSkeleton"``; confirm that the scan
# info actually provides ``"hasTriangleMesh"`` and ``"hasCurveSkeleton"`` entries.
filesUri_task_mapping = {
"PointCloud": "pointCloud",
"TriangleMesh": "mesh",
"CurveSkeleton": "skeleton",
"TreeGraph": "tree",
}

def get_scan_data(scan):
"""Get the scan information and data.
Expand Down Expand Up @@ -338,27 +369,13 @@ def get_scan_data(scan):
scan_data = get_scan_info(scan)
img_fs = scan.get_fileset(task_fs_map['images'])

def _get_file_uri(scan, fileset, file):
return f"/files/{scan.id}/{fileset.id}/{file.path().name}"

# Get the paths to data files:
scan_data["filesUri"] = {}
## Get the URI (file path) to the output of the `PointCloud` task:
if scan_data["hasPointCloud"]:
fs = scan.get_fileset(task_fs_map['PointCloud'])
scan_data["filesUri"]["pointCloud"] = _get_file_uri(scan, fs, fs.get_file('PointCloud'))
## Get the URI (file path) to the output of the `TriangleMesh` task:
if scan_data["hasMesh"]:
fs = scan.get_fileset(task_fs_map['TriangleMesh'])
scan_data["filesUri"]["mesh"] = _get_file_uri(scan, fs, fs.get_file('TriangleMesh'))
## Get the URI (file path) to the output of the `CurveSkeleton` task:
if scan_data["hasSkeleton"]:
fs = scan.get_fileset(task_fs_map['CurveSkeleton'])
scan_data["filesUri"]["skeleton"] = _get_file_uri(scan, fs, fs.get_file('CurveSkeleton'))
## Get the URI (file path) to the output of the `TreeGraph` task:
if scan_data["hasTreeGraph"]:
fs = scan.get_fileset(task_fs_map['TreeGraph'])
scan_data["filesUri"]["tree"] = _get_file_uri(scan, fs, fs.get_file('TreeGraph'))
for task, uri_key in filesUri_task_mapping.items():
if scan_data[f"has{task}"]:
fs = scan.get_fileset(task_fs_map[task])
scan_data["filesUri"][uri_key] = get_file_uri(scan, fs, fs.get_file(task))

# Load some of the data:
scan_data["data"] = {}
Expand Down
223 changes: 175 additions & 48 deletions src/plantdb/rest_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""
This module groups the client-side methods of the REST API.
"""
import json
from io import BytesIO
from pathlib import Path

Expand All @@ -34,12 +35,39 @@
from PIL import Image
from plyfile import PlyData

from plantdb.rest_api import filesUri_task_mapping
from plantdb.rest_api import get_file_uri

#: Default host of the REST API server (the local machine):
REST_API_URL = "127.0.0.1"
#: Default port of the REST API server:
REST_API_PORT = 5000


def base_url(host=REST_API_URL, port=REST_API_PORT):
    """Build the root URL of the PlantDB REST API served at the given host and port.

    Parameters
    ----------
    host : str
        The hostname or IP address of the PlantDB REST API server. Defaults to ``"127.0.0.1"``.
    port : int or str
        The port number of the PlantDB REST API server. Defaults to ``5000``.

    Returns
    -------
    str
        The root URL, formatted as ``http://<host>:<port>`` (no trailing slash).

    Examples
    --------
    >>> from plantdb.rest_api_client import base_url
    >>> base_url()
    'http://127.0.0.1:5000'
    """
    return "http://{}:{}".format(host, port)


def test_db_availability(host=REST_API_URL, port=REST_API_PORT):
"""Test the REST API server availability.
Expand All @@ -61,12 +89,13 @@ def test_db_availability(host=REST_API_URL, port=REST_API_PORT):
>>> test_db_availability()
"""
try:
requests.get(url=f"http://{host}:{port}/scans")
requests.get(url=f"{base_url(host, port)}/scans")
except requests.exceptions.ConnectionError:
return False
else:
return True


def get_scans_info(host=REST_API_URL, port=REST_API_PORT):
"""Retrieve the information dictionary for all scans from the PlantDB REST API.
Expand All @@ -88,7 +117,7 @@ def get_scans_info(host=REST_API_URL, port=REST_API_PORT):
>>> # This example requires the PlantDB REST API to be active (`fsdb_rest_api --test` from plantdb library)
>>> get_scans_info()
"""
return requests.get(url=f"http://{host}:{port}/scans").json()
return requests.get(url=f"{base_url(host, port)}/scans").json()


def parse_scans_info(host=REST_API_URL, port=REST_API_PORT):
Expand Down Expand Up @@ -149,7 +178,7 @@ def get_scan_data(scan_id, host=REST_API_URL, port=REST_API_PORT):
>>> print(scan_data['hasColmap'])
False
"""
return requests.get(url=f"http://{host}:{port}/scans/{scan_id}").json()
return requests.get(url=f"{base_url(host, port)}/scans/{scan_id}").json()


def list_scan_names(host=REST_API_URL, port=REST_API_PORT):
Expand Down Expand Up @@ -203,7 +232,7 @@ def scan_preview_image_url(scan_id, host=REST_API_URL, port=REST_API_PORT, size=
thumb_uri = get_scan_data(scan_id, host, port)["thumbnailUri"]
if size != "thumb":
thumb_uri = thumb_uri.replace("size=thumb", f"size={size}")
return f"http://{host}:{port}{thumb_uri}"
return f"{base_url(host, port)}{thumb_uri}"


def scan_image_url(scan_id, fileset_id, file_id, size='orig', host=REST_API_URL, port=REST_API_PORT):
Expand Down Expand Up @@ -233,7 +262,7 @@ def scan_image_url(scan_id, fileset_id, file_id, size='orig', host=REST_API_URL,
str
The URL to an image of a scan dataset and task fileset.
"""
return f"http://{host}:{port}/image/{scan_id}/{fileset_id}/{file_id}?size={size}"
return f"{base_url(host, port)}/image/{scan_id}/{fileset_id}/{file_id}?size={size}"


def get_scan_image(scan_id, fileset_id, file_id, size='orig', host=REST_API_URL, port=REST_API_PORT):
Expand Down Expand Up @@ -317,9 +346,8 @@ def list_task_images_uri(dataset_name, task_name='images', size='orig', **api_kw
scan_info = get_scan_data(dataset_name, **api_kwargs)
tasks_fileset = scan_info["tasks_fileset"]
images = scan_info["images"]
base_url = f"http://{api_kwargs['host']}:{api_kwargs['port']}"
return [base_url + f"/image/{dataset_name}/{tasks_fileset[task_name]}/{Path(img).stem}?size={size}" for img in
images]
url = base_url(**api_kwargs)
return [url + f"/image/{dataset_name}/{tasks_fileset[task_name]}/{Path(img).stem}?size={size}" for img in images]


def get_images_from_task(dataset_name, task_name='images', size='orig', **api_kwargs):
Expand Down Expand Up @@ -353,7 +381,7 @@ def get_images_from_task(dataset_name, task_name='images', size='orig', **api_kw
Examples
--------
>>> from plantdb.rest_api_client import get_images_from_task
>>> images = get_images_from_task('real_plant')
>>> images = get_images_from_task('real_plant', host='127.0.0.1', port='5000')
>>> print(len(images))
60
>>> img1 = images[0]
Expand All @@ -366,56 +394,155 @@ def get_images_from_task(dataset_name, task_name='images', size='orig', **api_kw
return images


def _ply_pcd_to_array(ply_pcd):
"""Convert the `PlyData` into an XYZ array of vertex coordinates.
def _ply_vertex_to_array(data):
"""Convert the `PlyData` 'vertex' data into an XYZ array of vertex coordinates.
Parameters
----------
ply_pcd : PlyData
data : PlyData
The `PlyData` object to be converted as numpy array.
Returns
-------
numpy.ndarray
The XYZ array of vertex coordinates.
"""
return np.array([ply_pcd['vertex']['x'], ply_pcd['vertex']['y'], ply_pcd['vertex']['z']]).T
return np.array([data['vertex']['x'], data['vertex']['y'], data['vertex']['z']]).T

def _ply_face_to_array(data):
"""Convert the `PlyData` 'face' data into an XYZ array of triangle coordinates.
def get_dataset_data(dataset_name, **api_kwargs):
""""""
base_url = f"http://{api_kwargs['host']}:{api_kwargs['port']}"
scan_info = get_scan_data(dataset_name, **api_kwargs)
Parameters
----------
data : PlyData
The `PlyData` object to be converted as numpy array.
# Get pointcloud data from REST API:
## Use the file URI to get the fileset id & file id:
pcd_fileset_id = scan_info["tasks_fileset"].get('PointCloud', None)
pcd_file_id = "PointCloud"
## Use the REST API to stream the pointcloud:
pcd_file = requests.get(base_url + f"/pointcloud/{dataset_name}/{pcd_fileset_id}/{pcd_file_id}").content
Returns
-------
numpy.ndarray
The XYZ array of triangle coordinates.
"""
return np.array([d for d in data['face'].data['vertex_indices']]).T

def parse_requests_pcd(data):
    """Parse the content of a request, expected to come from a PointCloud task source.

    Parameters
    ----------
    data : buffer
        The raw content of the request, read as a PLY file.

    Returns
    -------
    numpy.ndarray
        The parsed pointcloud, with vertex coordinates sorted as XYZ.
    """
    # Expose the raw content as a file-like object, as required to read it as a `PlyData`:
    stream = BytesIO(data)
    pointcloud = PlyData.read(stream)
    # Keep only the coordinates of the vertices:
    return _ply_vertex_to_array(pointcloud)


def parse_requests_mesh(data):
"""Parse a requests content, should be from a TriangleMesh task source.
Parameters
----------
data : buffer
The data source from a requests content.
Returns
-------
dict
The parsed triangular mesh with two entries: 'vertices' for vertex coordinates and 'triangles' for triangle coordinates.
"""
## Read the PLY as a `PlyData`:
ply_pcd = PlyData.read(BytesIO(pcd_file))
## Convert the `PlyData` into an XYZ array of vertex coordinates:
pcd_array = _ply_pcd_to_array(ply_pcd)

# Get mesh data from REST API:
## Use the file URI to get the fileset id & file id:
mesh_uri = scan_info["filesUri"]["mesh"]
mesh_fileset_id, mesh_file_id = Path(mesh_uri).parts[-2:]
mesh_file_id = Path(mesh_file_id).stem
## Use the REST API to stream the pointcloud:
mesh_file = requests.get(base_url + f"/mesh/{dataset_name}/{mesh_fileset_id}/{mesh_file_id}").content
mesh_data = PlyData.read(BytesIO(mesh_file))
mesh_data = None

# Get tree data from REST API:
## Use the file URI to get the fileset id & file id:
tree_uri = scan_info["filesUri"]["tree"]
## Use the REST API to stream the pointcloud:
tree_file = requests.get(base_url + tree_uri).content
tree_data = None # tree files are NetworkX graphs!

fruit_dir = None
stem_dir = None
return {"PointCloud": pcd_array, "TriangleMesh": mesh_data, "TreeGraph": tree_data,
"FruitDirection": fruit_dir, "StemDirection": stem_dir}
mesh_data = PlyData.read(BytesIO(data))
## Convert the `PlyData`:
return {"vertices": _ply_vertex_to_array(mesh_data),
"triangles": _ply_face_to_array(mesh_data)}


def parse_requests_skeleton(data):
    """Parse the content of a request, expected to come from a CurveSkeleton task source.

    Parameters
    ----------
    data : buffer
        The raw content of the request, decoded as a JSON document.

    Returns
    -------
    dict
        The parsed skeleton, expected to hold two entries: 'points' for the point coordinates
        and 'lines' joining them.
    """
    skeleton = json.loads(data)
    return skeleton


def parse_requests_tree(data):
    """Parse the content of a request, expected to come from a TreeGraph task source.

    Parameters
    ----------
    data : buffer
        The raw content of the request, a pickled object.

    Returns
    -------
    networkx.Graph
        The unpickled object, expected to be the (tree) graph.

    Notes
    -----
    SECURITY: unpickling can execute arbitrary code, only use this against a trusted REST API server.
    """
    import pickle
    tree = pickle.loads(data)
    return tree


def parse_requests_json(data):
    """Parse the content of a request holding a JSON document.

    This is the parser registered for the AnglesAndInternodes task source and for the 'json' file extension.

    Parameters
    ----------
    data : buffer
        The raw content of the request, decoded as a JSON document.

    Returns
    -------
    dict
        The decoded JSON document, *e.g.* the angles and internodes dictionary
        with its 'angles' and 'internodes' entries.
    """
    decoded = json.loads(data)
    return decoded


# Dispatch table: task name -> function parsing the content of a request for that task.
PARSER_DICT = {
"PointCloud": parse_requests_pcd,
"TriangleMesh": parse_requests_mesh,
"CurveSkeleton": parse_requests_skeleton,
"TreeGraph": parse_requests_tree,
"AnglesAndInternodes": parse_requests_json,
}

# Dispatch table: file extension (without the leading dot) -> function parsing the content of a request.
# When an extension is given to `parse_task_requests_data`, it is looked up here instead of the task name.
EXT_PARSER_DICT = {
"json": parse_requests_json,
}

def parse_task_requests_data(task, data, extension=None):
    """Parse the content of a request with the parser matching its file extension or task.

    Parameters
    ----------
    task : str
        The name of the task the data comes from, used to select the parser in ``PARSER_DICT``.
    data : buffer
        The raw content of the request.
    extension : str, optional
        The file extension, without the leading dot. If it has a dedicated parser in ``EXT_PARSER_DICT``,
        that parser takes precedence over the one of the task.

    Returns
    -------
    Any
        The parsed data, its type depends on the selected parser.

    Notes
    -----
    An extension without a dedicated parser (*e.g.* ``'ply'``) falls back to the parser of the task
    instead of raising a ``KeyError``. A task without a dedicated parser defaults to the JSON parser.
    """
    data_parser = None
    if extension is not None:
        # Use `get` as most extensions have no dedicated parser:
        data_parser = EXT_PARSER_DICT.get(extension)
    if data_parser is None:
        data_parser = PARSER_DICT.get(task, parse_requests_json)
    return data_parser(data)


def get_task_data(dataset_name, task, filename=None, **api_kwargs):
    """Fetch a file of a task from the PlantDB REST API and parse its content.

    Parameters
    ----------
    dataset_name : str
        The name of the scan dataset.
    task : str
        The name of the task to get the data from.
    filename : str, optional
        The name of the file to fetch from the fileset of the task.
        By default (``None``), use the URI registered for this task under the ``"filesUri"`` entry of the scan data.
    **api_kwargs
        The ``host`` and ``port`` of the PlantDB REST API server, forwarded to `base_url` and `get_scan_data`.

    Returns
    -------
    Any
        The parsed data, its type depends on the parser selected by `parse_task_requests_data`.
    """
    url = base_url(**api_kwargs)
    scan_info = get_scan_data(dataset_name, **api_kwargs)
    # Get data from `File` resource of REST API:
    ext = None
    if filename is None:
        file_uri = scan_info["filesUri"][filesUri_task_mapping[task]]
    else:
        # `Path.suffix` is an empty string for a name without extension, so strip the leading dot
        # rather than unpack a split, and map "no extension" to ``None``:
        ext = Path(filename).suffix.lstrip('.') or None
        file_uri = get_file_uri(dataset_name, scan_info["tasks_fileset"][task], filename)
    data = requests.get(url + file_uri).content

    return parse_task_requests_data(task, data, ext)

0 comments on commit e243600

Please sign in to comment.