Skip to content

Commit

Permalink
Merge pull request #43 from h-munakata/muna/amr_api_demo
Browse files Browse the repository at this point in the history
Add AMR demo of inference API
  • Loading branch information
awkrail authored Oct 22, 2024
2 parents c2d1d3e + 8863f08 commit 7ae1a79
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
41 changes: 41 additions & 0 deletions api_example/amr_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Copyright $today.year LY Corporation
LY Corporation licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""
import os
import subprocess
import torch

from lighthouse.models import QDDETRPredictor
from typing import Dict, List, Optional

def load_weights(weight_dir: str) -> None:
if not os.path.exists(os.path.join(weight_dir, 'clap_qd_detr_clotho-moment.ckpt')):
command = 'wget -P gradio_demo/weights/ https://zenodo.org/records/13961029/files/clap_qd_detr_clotho-moment.ckpt'
subprocess.run(command, shell=True)

# use GPU if available
device: str = 'cpu'
weight_dir: str = 'gradio_demo/weights'
weight_path: str = os.path.join(weight_dir, 'clap_qd_detr_clotho-moment.ckpt')
load_weights(weight_dir)
model: QDDETRPredictor = QDDETRPredictor(weight_path, device=device, feature_name='clap')

# encode audio features
model.encode_audio('api_example/1a-ODBWMUAE.wav')

# moment retrieval
query: str = 'Water cascades down from a waterfall.'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
print(prediction)
3 changes: 2 additions & 1 deletion api_example/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def load_weights(weight_dir: str) -> None:
device: str = 'cpu'
weight_dir: str = 'gradio_demo/weights'
weight_path: str = os.path.join(weight_dir, 'clip_slowfast_cg_detr_qvhighlight.ckpt')
load_weights(weight_dir)
model: CGDETRPredictor = CGDETRPredictor(weight_path, device=device, feature_name='clip_slowfast',
slowfast_path='SLOWFAST_8x8_R50.pkl', pann_path=None)

Expand All @@ -44,4 +45,4 @@ def load_weights(weight_dir: str) -> None:
# moment retrieval & highlight detection
query: str = 'A woman wearing a glass is speaking in front of the camera'
prediction: Optional[Dict[str, List[float]]] = model.predict(query)
print(prediction)
print(prediction)

0 comments on commit 7ae1a79

Please sign in to comment.