Implement batch VLM predicate evaluation pipeline on Spot #285

Open
wants to merge 48 commits into
base: master
48 commits
0824ae7
try to visualize sokoban - may need a neater way
lf-zhao Apr 23, 2024
bf4000d
add assert for fast downward planner in running once func
lf-zhao Apr 23, 2024
19d9d53
try to use `rich` package for more structured console output!
lf-zhao Apr 23, 2024
feaa264
upload a naive way to store images
lf-zhao Apr 25, 2024
dd1fbf6
debug
lf-zhao Apr 25, 2024
32a06a3
upload - manual copy from Nishanth's VLM interface in LIS predicators…
lf-zhao Apr 25, 2024
ee3f634
add OpenAI vlm - in progress
lf-zhao Apr 25, 2024
a1a67fd
update config setting for using vlm
lf-zhao Apr 25, 2024
5d9f12e
add package
lf-zhao Apr 26, 2024
cf3dbf7
another missed one
lf-zhao Apr 26, 2024
3001f32
manually add Nishanth's new pretrained model interface for now, see L…
lf-zhao Apr 29, 2024
4d47211
add new OpenAI VLM class, add example to use
lf-zhao Apr 29, 2024
e45a0e9
add a flag for caching
lf-zhao Apr 29, 2024
8786dcd
now the example working - fix requesting vision messages, update test
lf-zhao Apr 29, 2024
71d0b3e
update; add choosing img detail quality
lf-zhao Apr 29, 2024
112b690
include the test image I used for now, not sure what I should do with…
lf-zhao Apr 29, 2024
f884df7
remove original vlm interface, already merged into latest pretrained …
lf-zhao Apr 29, 2024
8c17c42
Merge branch 'refs/heads/master' into lis-spot/implement-vlm-predicat…
lf-zhao Apr 30, 2024
aec70de
found a way to use VLM to evaluate; add current images and also visib…
lf-zhao Apr 30, 2024
94e6a4c
found a way to use VLM to evaluate; check if visible in current scene…
lf-zhao Apr 30, 2024
3ee2ba9
update State struct; adding to Spot specific subclass doesn't work, n…
lf-zhao Apr 30, 2024
1abf488
add detail option
lf-zhao May 1, 2024
1c82c44
working; implement On predicate with VLM classifier pipeline! add cal…
lf-zhao May 1, 2024
eeb1583
make a separate function for vlm predicate classifier evaluation
lf-zhao May 1, 2024
acbdb0a
add test
lf-zhao May 3, 2024
01498f6
update example, move to test, move img
lf-zhao May 3, 2024
0774354
remove
lf-zhao May 3, 2024
68ef57d
format
lf-zhao May 4, 2024
6560a94
update
lf-zhao May 4, 2024
1596f68
batch VLM classifier working on Spot!! add field to State, add VLMPre…
lf-zhao May 8, 2024
98ae2a9
batch VLM classifier eval: add vlm predicates fields to observation
lf-zhao May 8, 2024
2683a6b
batch VLM classifier eval: function on batch query and parse
lf-zhao May 8, 2024
b0b7568
batch VLM classifier eval: provide VLM predicates to object finding, …
lf-zhao May 8, 2024
b4718f7
batch VLM classifier eval: add VLM predicate fields to state+obs, bui…
lf-zhao May 8, 2024
3ca9032
formatting
lf-zhao May 8, 2024
1ab3d36
remove some comments
lf-zhao May 8, 2024
11f91ff
remove some comments
lf-zhao May 8, 2024
c98e2fe
fix, add tenacity
lf-zhao May 8, 2024
4973f1f
fix structs
lf-zhao May 8, 2024
26239e8
more fix
lf-zhao May 8, 2024
e78e342
update
lf-zhao May 9, 2024
ce43a76
some clean
lf-zhao May 9, 2024
93fb574
fix no VLM case
lf-zhao May 9, 2024
cb0e1ee
add predicate prompt; fix and clean
lf-zhao May 9, 2024
63a1429
add predicate prompt & some logging; fix and clean
lf-zhao May 9, 2024
ba2575f
overwrite vlm predicate classifier; reformat
lf-zhao May 9, 2024
a11e5a1
update
lf-zhao May 9, 2024
8ba4c5d
update vlm query in obj finding
lf-zhao May 9, 2024
1 change: 1 addition & 0 deletions predicators/args.py
@@ -43,6 +43,7 @@ def create_arg_parser(env_required: bool = True,
     parser.add_argument("--experiment_id", default="", type=str)
     parser.add_argument("--load_experiment_id", default="", type=str)
     parser.add_argument("--log_file", default="", type=str)
+    parser.add_argument("--log_rich", default="true", type=str)
     parser.add_argument("--use_gui", action="store_true")
     parser.add_argument('--debug',
                         action="store_const",
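
Because `--log_rich` is declared with `type=str`, argparse hands back a string, and any non-empty string (including `"false"`) is truthy in a bare `if` check. The following is a minimal sketch of one way to turn such a flag into a real boolean. The helper `str_to_bool` and the stand-in `build_parser` are hypothetical names for illustration and are not part of this PR.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical helper: map common spellings of true/false to a bool."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    # argparse turns this exception into a normal usage error.
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Stand-in for create_arg_parser, reduced to the flag of interest."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_rich", default=True, type=str_to_bool)
    return parser


# The string "false" is truthy on its own, but parses to False here.
args = build_parser().parse_args(["--log_rich", "false"])
```

With a converter like this, `--log_rich false` would disable the rich handler instead of silently leaving it on.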
218 changes: 178 additions & 40 deletions predicators/envs/spot_env.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions predicators/main.py
@@ -71,8 +71,14 @@ def main() -> None:
     args = utils.parse_args()
     utils.update_config(args)
     str_args = " ".join(sys.argv)
-    # Log to stderr.
-    handlers: List[logging.Handler] = [logging.StreamHandler()]
+    # Log to stderr or use `rich` package for more structured output.
+    handlers: List[logging.Handler] = []
+    if CFG.log_rich:
+        from rich.logging import RichHandler
+        handlers.append(RichHandler())
+    else:
+        handlers.append(logging.StreamHandler())
+
     if CFG.log_file:
         handlers.append(logging.FileHandler(CFG.log_file, mode='w'))
     logging.basicConfig(level=CFG.loglevel,
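
The handler selection above can be factored into a small function, which makes it easier to test and lets the code degrade gracefully when `rich` is not installed. This is only a sketch: `build_handlers` is a hypothetical name, and the fallback to a plain `StreamHandler` on `ImportError` is an assumption, not behavior from this PR.

```python
import logging
from typing import List


def build_handlers(use_rich: bool, log_file: str = "") -> List[logging.Handler]:
    """Hypothetical helper mirroring the handler setup in main()."""
    handlers: List[logging.Handler] = []
    if use_rich:
        try:
            # Imported lazily so `rich` is only needed when requested.
            from rich.logging import RichHandler
            handlers.append(RichHandler())
        except ImportError:
            # Assumption: plain stderr output is an acceptable fallback.
            handlers.append(logging.StreamHandler())
    else:
        handlers.append(logging.StreamHandler())
    if log_file:
        # Same mode as in the diff: overwrite the file on each run.
        handlers.append(logging.FileHandler(log_file, mode="w"))
    return handlers


plain_handlers = build_handlers(use_rich=False)
```

The returned list could then be passed to `logging.basicConfig(handlers=...)` as `main()` already does.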
26 changes: 24 additions & 2 deletions predicators/perception/spot_perceiver.py
@@ -200,6 +200,18 @@ def _update_state_from_observation(self, observation: Observation) -> None:
         for obj in observation.objects_in_view:
             self._lost_objects.discard(obj)

+        # NOTE: This is only used when using VLM for predicate evaluation
+        # NOTE: Performance aspect should be considered later
+        if CFG.spot_vlm_eval_predicate:
+            # Add current Spot images to the state if needed
+            self._camera_images = observation.images
+            self._vlm_atom_dict = observation.vlm_atom_dict
+            self._vlm_predicates = observation.vlm_predicates
+        else:
+            self._camera_images = None
+            self._vlm_atom_dict = None
+            self._vlm_predicates = None
+
     def _create_state(self) -> State:
         if self._waiting_for_observation:
             return DefaultState
@@ -281,9 +293,19 @@ def _create_state(self) -> State:
         # logging.info("Simulator state:")
         # logging.info(simulator_state)

+        # Prepare the current images from observation
+        camera_images = self._camera_images if CFG.spot_vlm_eval_predicate else None
+
         # Now finish the state.
-        state = _PartialPerceptionState(percept_state.data,
-                                        simulator_state=simulator_state)
+        state = _PartialPerceptionState(
+            percept_state.data,
+            simulator_state=simulator_state,
+            camera_images=camera_images,
+            visible_objects=self._objects_in_view,
+            vlm_atom_dict=self._vlm_atom_dict,
+            vlm_predicates=self._vlm_predicates,
+        )
+        # DEBUG - look into dataclass field init - why warning

         return state

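
The diff does not show how `vlm_atom_dict` is populated, since the `spot_env.py` changes are not rendered. As a rough illustration only, one way a batched query could work is to number all predicate questions in a single prompt and parse one answer per line. Every name below (`build_batch_prompt`, `parse_batch_response`, the prompt wording, the answer format) is an assumption made for this sketch and may differ from what the PR actually does.

```python
from typing import Dict, List, Optional


def build_batch_prompt(queries: List[str]) -> str:
    """Number each query so answers can be matched back by index."""
    header = ("Answer each question with True, False, or Unknown, "
              "one per line, in the form '<number>. <answer>'.")
    numbered = [f"{i}. {q}" for i, q in enumerate(queries, start=1)]
    return "\n".join([header] + numbered)


def parse_batch_response(queries: List[str],
                         response: str) -> Dict[str, Optional[bool]]:
    """Map each query to True/False, or None if no usable answer."""
    results: Dict[str, Optional[bool]] = {q: None for q in queries}
    for line in response.splitlines():
        number, sep, answer = line.strip().partition(".")
        if not sep or not number.strip().isdigit():
            continue  # Skip lines that are not '<number>. <answer>'.
        idx = int(number.strip()) - 1
        if not 0 <= idx < len(queries):
            continue  # Ignore answers for indices we never asked about.
        word = answer.strip().lower()
        if word.startswith("true"):
            results[queries[idx]] = True
        elif word.startswith("false"):
            results[queries[idx]] = False
    return results


queries = ["On(cup, table)", "Inside(ball, bucket)"]
prompt = build_batch_prompt(queries)
parsed = parse_batch_response(queries, "1. True\n2. False")
```

Keeping unanswered queries as `None` rather than `False` leaves room for the caller to fall back to a previous value when the model does not commit to an answer.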
3 changes: 3 additions & 0 deletions predicators/planning.py
@@ -1215,6 +1215,9 @@ def run_task_plan_once(
             raise PlanningFailure(
                 "Skeleton produced by A-star exceeds horizon!")
     elif "fd" in CFG.sesame_task_planner:  # pragma: no cover
+        # Run Fast Downward. See the instructions in the docstring of `_sesame_plan_with_fast_downward`
+        assert "FD_EXEC_PATH" in os.environ, \
+            "Please follow the instructions in the docstring of this method!"
         fd_exec_path = os.environ["FD_EXEC_PATH"]
         exec_str = os.path.join(fd_exec_path, "fast-downward.py")
         timeout_cmd = "gtimeout" if sys.platform == "darwin" \
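
Note that `assert` statements are skipped when Python runs with `-O`, so an environment check written as an assertion can silently disappear. A minimal sketch of an explicit check follows. The helper name `get_fd_script_path` and the error message are hypothetical, and the `environ` parameter exists only to make the function easy to test.

```python
import os
from typing import Mapping


def get_fd_script_path(environ: Mapping[str, str] = os.environ) -> str:
    """Hypothetical helper: locate fast-downward.py or fail loudly."""
    fd_exec_path = environ.get("FD_EXEC_PATH")
    if not fd_exec_path:
        # Name the function whose docstring holds the setup instructions.
        raise RuntimeError(
            "FD_EXEC_PATH is not set; see the docstring of "
            "_sesame_plan_with_fast_downward for setup instructions.")
    return os.path.join(fd_exec_path, "fast-downward.py")


example_path = get_fd_script_path({"FD_EXEC_PATH": "/opt/fd"})
```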
114 changes: 112 additions & 2 deletions predicators/pretrained_model_interface.py
@@ -5,16 +5,20 @@
"""

import abc
import base64
import logging
import os
import time
from typing import List, Optional
from io import BytesIO
from typing import Dict, List, Optional

import cv2
import google
import google.generativeai as genai
import imagehash
import openai
import PIL.Image
from tenacity import retry, stop_after_attempt, wait_random_exponential

from predicators.settings import CFG

@@ -74,7 +78,7 @@ def sample_completions(self,
model_id = self.get_id()
prompt_id = hash(prompt)
config_id = f"{temperature}_{seed}_{num_completions}_" + \
f"{stop_token}"
f"{stop_token}"
# If the temperature is 0, the seed does not matter.
if temperature == 0.0:
config_id = f"most_likely_{num_completions}_{stop_token}"
@@ -249,3 +253,109 @@ def _sample_completions(
                 time.sleep(3.0)
         response.resolve()
         return [response.text]
+
+
+class OpenAIVLM(VisionLanguageModel):
+    """Interface for OpenAI's VLMs, including GPT-4 Turbo (and preview
+    versions)."""
+
+    def __init__(self, model_name: str = "gpt-4-turbo", detail: str = "auto"):
+        """Initialize with a specific model name."""
+        self.model_name = model_name
+        self.detail = detail
+        assert "OPENAI_API_KEY" in os.environ
+        openai.api_key = os.getenv("OPENAI_API_KEY")
+
+    def prepare_vision_messages(self,
+                                images: List[PIL.Image.Image],
+                                prefix: Optional[str] = None,
+                                suffix: Optional[str] = None,
+                                image_size: Optional[int] = 512,
+                                detail: str = "auto") -> List[Dict[str, str]]:
+        """Prepare text and image messages for the OpenAI API."""
+        content = []
+
+        if detail is None or detail == "auto":
+            detail = self.detail
+
+        if prefix:
+            content.append({"text": prefix, "type": "text"})
+
+        assert images
+        assert detail in ["auto", "low", "high"]
+        for img in images:
+            img_resized = img
+            if image_size:
+                factor = image_size / max(img.size)
+                img_resized = img.resize(
+                    (int(img.size[0] * factor), int(img.size[1] * factor)))
+
+            # Convert the image to PNG format and encode it in base64
+            buffer = BytesIO()
+            img_resized.save(buffer, format="PNG")
+            buffer_bytes = buffer.getvalue()
+            frame = base64.b64encode(buffer_bytes).decode("utf-8")
+
+            content.append({
+                "image_url": {
+                    "url": f"data:image/png;base64,{frame}",
+                    "detail": detail
+                },
+                "type": "image_url"
+            })
+
+        if suffix:
+            content.append({"text": suffix, "type": "text"})
+
+        return [{"role": "user", "content": content}]
+
+    @retry(wait=wait_random_exponential(min=1, max=60),
+           stop=stop_after_attempt(6))
+    def call_openai_api(self,
+                        messages: list,
+                        model: str = "gpt-4",
+                        seed: Optional[int] = None,
+                        max_tokens: int = 32,
+                        temperature: float = 0.2,
+                        verbose: bool = False) -> str:
+        """Make an API call to OpenAI."""
+        client = openai.OpenAI()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            seed=seed,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+        if verbose:
+            print(f"OpenAI API response: {completion}")
+        assert len(completion.choices) == 1
+        return completion.choices[0].message.content
+
+    def get_id(self) -> str:
+        """Get an identifier for the model."""
+        return f"OpenAI-{self.model_name}"
+
+    def _sample_completions(
+        self,
+        prompt: str,
+        imgs: Optional[List[PIL.Image.Image]],
+        temperature: float,
+        seed: int,
+        stop_token: Optional[str] = None,
+        num_completions: int = 1,
+        max_tokens: int = 512,
+    ) -> List[str]:
+        """Query the model and get responses."""
+        assert imgs is not None
+        messages = self.prepare_vision_messages(prefix=prompt,
+                                                images=imgs,
+                                                detail="auto")
+        responses = [
+            self.call_openai_api(messages,
+                                 model=self.model_name,
+                                 max_tokens=max_tokens,
+                                 temperature=temperature)
+            for _ in range(num_completions)
+        ]
+        return responses
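
To make the message layout built by `prepare_vision_messages` easier to inspect without PIL or an API key, here is a standard-library-only sketch of the same two steps: computing the resized dimensions and wrapping base64-encoded PNG bytes in a chat message. The function names `scaled_size` and `build_vision_message` are hypothetical, and the sketch takes already-encoded PNG bytes instead of `PIL.Image` objects.

```python
import base64
from typing import Any, Dict, List, Optional, Tuple


def scaled_size(size: Tuple[int, int], max_side: int) -> Tuple[int, int]:
    """Scale (width, height) so that the longer side equals max_side."""
    factor = max_side / max(size)
    return int(size[0] * factor), int(size[1] * factor)


def build_vision_message(png_images: List[bytes],
                         prefix: Optional[str] = None,
                         suffix: Optional[str] = None,
                         detail: str = "auto") -> List[Dict[str, Any]]:
    """Build one user message holding optional text and inline images."""
    assert png_images
    assert detail in ("auto", "low", "high")
    content: List[Dict[str, Any]] = []
    if prefix:
        content.append({"type": "text", "text": prefix})
    for png in png_images:
        encoded = base64.b64encode(png).decode("utf-8")
        content.append({
            "type": "image_url",
            "image_url": {
                # Images are sent inline as a base64 data URL.
                "url": f"data:image/png;base64,{encoded}",
                "detail": detail,
            },
        })
    if suffix:
        content.append({"type": "text", "text": suffix})
    return [{"role": "user", "content": content}]


# b"abc" is placeholder data standing in for real PNG bytes.
messages = build_vision_message([b"abc"],
                                prefix="Is the cup on the table?",
                                detail="low")
```

With the resize formula used in the diff, an image whose longer side is smaller than `image_size` is scaled up as well, because the factor is then greater than 1.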
6 changes: 6 additions & 0 deletions predicators/settings.py
@@ -46,6 +46,9 @@ class GlobalSettings:
     # your call to utils.reset_config().
     render_state_dpi = 150
     approach_wrapper = None
+    # Use VLMs to evaluate some spatial predicates in visual environment,
+    # e.g., Sokoban. Still work in progress.
+    enable_vlm_eval_predicate = False

     # cover_multistep_options env parameters
     cover_multistep_action_limits = [-np.inf, np.inf]
@@ -178,6 +181,9 @@ class GlobalSettings:
     spot_run_dry = False
     spot_use_perfect_samplers = False  # for debugging
     spot_sweep_env_goal_description = "get the objects into the bucket"
+    # Evaluate some predicates with VLM; need additional setup; WIP
+    spot_vlm_eval_predicate = False
+    vlm_eval_verbose = False

     # pddl blocks env parameters
     pddl_blocks_procedural_train_min_num_blocks = 3
20 changes: 11 additions & 9 deletions predicators/spot_utils/perception/object_detection.py
@@ -473,15 +473,17 @@ def get_random_mask_pixel_from_artifacts(
     mask_idx = rng.choice(len(pixels_in_mask))
     pixel_tuple = (pixels_in_mask[1][mask_idx], pixels_in_mask[0][mask_idx])
     # Uncomment to plot the grasp pixel being selected!
-    # rgb_img = artifacts["language"]["rgbds"][camera_name].rgb
-    # _, axes = plt.subplots()
-    # axes.imshow(rgb_img)
-    # axes.add_patch(
-    #     plt.Rectangle((pixel_tuple[0], pixel_tuple[1]), 5, 5, color='red'))
-    # plt.tight_layout()
-    # outdir = Path(CFG.spot_perception_outdir)
-    # plt.savefig(outdir / "grasp_pixel.png", dpi=300)
-    # plt.close()
+    """
+    rgb_img = artifacts["language"]["rgbds"][camera_name].rgb
+    _, axes = plt.subplots()
+    axes.imshow(rgb_img)
+    axes.add_patch(
+        plt.Rectangle((pixel_tuple[0], pixel_tuple[1]), 5, 5, color='red'))
+    plt.tight_layout()
+    outdir = Path(CFG.spot_perception_outdir)
+    plt.savefig(outdir / "grasp_pixel.png", dpi=300)
+    plt.close()
+    """
     return pixel_tuple

