Skip to content

Commit

Permalink
models from huggingface
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoa-xu committed Jun 24, 2024
1 parent d8534be commit 94f24f6
Show file tree
Hide file tree
Showing 2 changed files with 366 additions and 0 deletions.
124 changes: 124 additions & 0 deletions src/tflite_beam/contrib/tflite_beam_contrib_huggingface.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
-module(tflite_beam_contrib_huggingface).
-export([all_models/0, model_urls/1, download_model/1]).

-define(HUGGINGFACE_BASEURL, "https://huggingface.co/").

repo_download_baseurl(Repo, Branch) ->
?HUGGINGFACE_BASEURL ++ Repo ++ "/resolve/" ++ Branch.

model_urls(#{repo := Repo, files := Files}) ->
RepoBaseURL = repo_download_baseurl(Repo, "main"),
lists:map(fun(File) -> {Repo, File, RepoBaseURL ++ "/" ++ File ++ "?download=true"} end, Files).

download_model(ModelInfo) ->
ModelURLs = model_urls(ModelInfo),
lists:foldl(fun({Repo, File, URL}, Status) ->
case Status of
{ok, Acc} ->
case tflite_beam_utils_downloader:download(URL, Repo, File, false) of
{ok, DestionationFilename} ->
{ok, [DestionationFilename | Acc]};
{error, Reason} ->
{error, Reason}
end;
{error, Reason} ->
{error, Reason}
end
end, {ok, []}, ModelURLs).

all_models() ->
[
#{repo => "qualcomm/ConvNext-Tiny", files => ["ConvNext-Tiny.tflite"], task => image_classification},
#{repo => "qualcomm/DenseNet-121", files => ["DenseNet-121.tflite"], task => image_classification},
#{repo => "qualcomm/EfficientNet-B0", files => ["EfficientNet-B0.tflite"], task => image_classification},
#{repo => "qualcomm/GoogLeNet", files => ["GoogLeNet.tflite"], task => image_classification},
#{repo => "qualcomm/GoogLeNetQuantized", files => ["GoogLeNetQuantized.tflite"], task => image_classification},
#{repo => "qualcomm/HRNetPose", files => ["HRNetPose.tflite"], task => image_classification},
#{repo => "qualcomm/HRNetPoseQuantized", files => ["HRNetPoseQuantized.tflite"], task => image_classification},
#{repo => "qualcomm/Inception-v3", files => ["Inception-v3.tflite"], task => image_classification},
#{repo => "qualcomm/Inception-v3Quantized", files => ["Inception-v3Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/LiteHRNet", files => ["LiteHRNet.tflite"], task => image_classification},
#{repo => "qualcomm/MNASNet05", files => ["MNASNet05.tflite"], task => image_classification},
#{repo => "qualcomm/MediaPipe-Pose-Estimation", files => ["MediaPipePoseDetector.tflite", "MediaPipePoseLandmarkDetector.tflite"], task => image_classification},
#{repo => "qualcomm/MobileNet-v2", files => ["MobileNet-v2.tflite"], task => image_classification},
#{repo => "qualcomm/MobileNet-v2-Quantized", files => ["MobileNet-v2-Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/MobileNet-v3-Large", files => ["MobileNet-v3-Large.tflite"], task => image_classification},
#{repo => "qualcomm/MobileNet-v3-Small", files => ["MobileNet-v3-Small.tflite"], task => image_classification},
#{repo => "qualcomm/OpenAI-Clip", files => ["CLIPTextEncoder.tflite", "CLIPImageEncoder.tflite"], task => image_classification},
#{repo => "qualcomm/OpenPose", files => ["OpenPose.tflite"], task => image_classification},
#{repo => "qualcomm/RegNet", files => ["RegNet.tflite"], task => image_classification},
#{repo => "qualcomm/ResNet18", files => ["ResNet18.tflite"], task => image_classification},
#{repo => "qualcomm/ResNet18Quantized", files => ["ResNet18Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/ResNet50", files => ["ResNet50.tflite"], task => image_classification},
#{repo => "qualcomm/ResNet101", files => ["ResNet101.tflite"], task => image_classification},
#{repo => "qualcomm/ResNet101Quantized", files => ["ResNet101Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/ResNeXt50", files => ["ResNeXt50.tflite"], task => image_classification},
#{repo => "qualcomm/ResNeXt101", files => ["ResNeXt101.tflite"], task => image_classification},
#{repo => "qualcomm/ResNeXt101Quantized", files => ["ResNeXt101Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/Shufflenet-v2", files => ["Shufflenet-v2.tflite"], task => image_classification},
#{repo => "qualcomm/Shufflenet-v2Quantized", files => ["Shufflenet-v2Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/SqueezeNet-1_1", files => ["SqueezeNet-1_1.tflite"], task => image_classification},
#{repo => "qualcomm/SqueezeNet-1_1Quantized", files => ["SqueezeNet-1_1Quantized.tflite"], task => image_classification},
#{repo => "qualcomm/Swin-Base", files => ["Swin-Base.tflite"], task => image_classification},
#{repo => "qualcomm/Swin-Small", files => ["Swin-Small.tflite"], task => image_classification},
#{repo => "qualcomm/Swin-Tiny", files => ["Swin-Tiny.tflite"], task => image_classification},
#{repo => "qualcomm/VIT", files => ["VIT.tflite"], task => image_classification},
#{repo => "qualcomm/WideResNet50", files => ["WideResNet50.tflite"], task => image_classification},
#{repo => "qualcomm/WideResNet50-Quantized", files => ["WideResNet50-Quantized.tflite"], task => image_classification},

#{repo => "qualcomm/AOT-GAN", files => ["AOT-GAN.tflite"], task => image_to_image},
#{repo => "qualcomm/ESRGAN", files => ["ESRGAN.tflite"], task => image_to_image},
#{repo => "qualcomm/LaMa-Dilated", files => ["LaMa-Dilated.tflite"], task => image_to_image},
#{repo => "qualcomm/Real-ESRGAN-General-x4v3", files => ["Real-ESRGAN-General-x4v3.tflite"], task => image_to_image},
#{repo => "qualcomm/Real-ESRGAN-x4plus", files => ["Real-ESRGAN-x4plus.tflite"], task => image_to_image},
#{repo => "qualcomm/SESR-M5", files => ["SESR-M5.tflite"], task => image_to_image},
#{repo => "qualcomm/SESR-M5-Quantized", files => ["SESR-M5-Quantized.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetLarge", files => ["QuickSRNetLarge.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetLarge-Quantized", files => ["QuickSRNetLarge-Quantized.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetMedium", files => ["QuickSRNetMedium.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetMedium-Quantized", files => ["QuickSRNetMedium-Quantized.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetSmall", files => ["QuickSRNetSmall.tflite"], task => image_to_image},
#{repo => "qualcomm/QuickSRNetSmall-Quantized", files => ["QuickSRNetSmall-Quantized.tflite"], task => image_to_image},
#{repo => "qualcomm/XLSR", files => ["XLSR.tflite"], task => image_to_image},
#{repo => "qualcomm/XLSR-Quantized", files => ["XLSR-Quantized.tflite"], task => image_to_image},

#{repo => "qualcomm/DeepLabV3-ResNet50", files => ["DeepLabV3-ResNet50.tflite"], task => image_segmentation},
#{repo => "qualcomm/DDRNet23-Slim", files => ["DDRNet23-Slim.tflite"], task => image_segmentation},
#{repo => "qualcomm/FastSam-S", files => ["FastSam-S.tflite"], task => image_segmentation},
#{repo => "qualcomm/FastSam-X", files => ["FastSam-X.tflite"], task => image_segmentation},
#{repo => "qualcomm/FCN_ResNet50", files => ["FCN_ResNet50.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-40S", files => ["FFNet-40S.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-40S-Quantized", files => ["FFNet-40S-Quantized.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-54S", files => ["FFNet-54S.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-54S-Quantized", files => ["FFNet-54S-Quantized.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-78S", files => ["FFNet-78S.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-78S-LowRes", files => ["FFNet-78S-LowRes.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-78S-Quantized", files => ["FFNet-78S-Quantized.tflite"], task => image_segmentation},
#{repo => "qualcomm/FFNet-122NS-LowRes", files => ["FFNet-122NS-LowRes.tflite"], task => image_segmentation},
#{repo => "qualcomm/MediaPipe-Selfie-Segmentation", files => ["MediaPipe-Selfie-Segmentation.tflite"], task => image_segmentation},
#{repo => "qualcomm/Segment-Anything-Model", files => ["SAMDecoder.tflite"], task => image_segmentation},
#{repo => "qualcomm/SINet", files => ["SINet.tflite"], task => image_segmentation},
#{repo => "qualcomm/Unet-Segmentation", files => ["Unet-Segmentation.tflite"], task => image_segmentation},
#{repo => "qualcomm/Yolo-v7", files => ["Yolo-v7.tflite"], task => image_segmentation},
#{repo => "qualcomm/Yolo-v8-Segmentation", files => ["Yolo-v8-Segmentation.tflite"], task => image_segmentation},
#{repo => "qualcomm/YOLOv8-Segmentation", files => ["Yolo-v8-Segmentation.tflite"], task => image_segmentation}

#{repo => "qualcomm/HuggingFace-WavLM-Base-Plus", files => ["HuggingFace-WavLM-Base-Plus.tflite"], task => audio_speech_recognition},
#{repo => "qualcomm/Whisper-Small-En", files => ["WhisperDecoder.tflite", "WhisperEncoder.tflite"], task => audio_speech_recognition},
#{repo => "qualcomm/Whisper-Base-En", files => ["WhisperDecoder.tflite", "WhisperEncoder.tflite"], task => audio_speech_recognition},
#{repo => "qualcomm/Whisper-Tiny-En", files => ["WhisperDecoder.tflite", "WhisperEncoder.tflite"], task => audio_speech_recognition},

#{repo => "qualcomm/DETR-ResNet50", files => ["DETR-ResNet50.tflite"], task => object_detection},
#{repo => "qualcomm/DETR-ResNet50-DC5", files => ["DETR-ResNet50-DC5.tflite"], task => object_detection},
#{repo => "qualcomm/DETR-ResNet101", files => ["DETR-ResNet101.tflite"], task => object_detection},
#{repo => "qualcomm/DETR-ResNet101-DC5", files => ["DETR-ResNet101-DC5.tflite"], task => object_detection},
#{repo => "qualcomm/MediaPipe-Face-Detection", files => ["MediaPipeFaceDetector.tflite", "MediaPipeFaceLandmarkDetector.tflite"], task => object_detection},
#{repo => "qualcomm/MediaPipe-Hand-Detection", files => ["MediaPipeHandDetector.tflite", "MediaPipeHandLandmarkDetector.tflite"], task => object_detection},
#{repo => "qualcomm/Yolo-v6", files => ["Yolo-v6.tflite"], task => object_detection},
#{repo => "qualcomm/Yolo-v8-Detection", files => ["Yolo-v8-Detection.tflite"], task => object_detection},
#{repo => "qualcomm/YOLOv8-Detection", files => ["Yolo-v8-Detection.tflite"], task => object_detection},

#{repo => "qualcomm/Facebook-Denoiser", files => ["Facebook-Denoiser.tflite"], task => audio_to_audio},
#{repo => "qualcomm/TrOCR", files => ["TrOCRDecoder.tflite", "TrOCREncoder.tflite"], task => image_to_text},
#{repo => "qualcomm/StyleGAN2", files => ["StyleGAN2.tflite"], task => unconditional_image_generation}
].
242 changes: 242 additions & 0 deletions src/tflite_beam/utils/tflite_beam_utils_downloader.erl
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
-module(tflite_beam_utils_downloader).
-export([download/4]).

download(URL, CacheSubdir, CacheFilename, ForceDownload) ->
case cache_path(CacheSubdir, CacheFilename, ForceDownload) of
{ok, Exists, DestionationDir, DestionationFilename} ->
if
Exists ->
{ok, DestionationFilename};
true ->
filelib:ensure_dir(DestionationDir),
do_download(URL, DestionationFilename)
end;
{error, Err} ->
{error, Err}
end.

cache_opts() ->
case os:getenv("MIX_XDG") of
false ->
#{};
_ ->
#{os => linux}
end.

maybe_override_by_env(EnvName, Default) ->
case os:getenv(EnvName) of
false ->
Default;
Value ->
Value
end.

mkdir_dir_p(Dir) ->
case file:make_dir(Dir) of
ok ->
ok;
{error, eexist} ->
ok;
{error, Reason} ->
{error, lists:flatten(io_lib:fwrite("Cannot create directory ~s: ~s", [Dir, Reason]))}
end.

cache_basepath() ->
CacheBaseDir = filename:basedir(user_cache, "tflite_beam", cache_opts()),
CacheDir = maybe_override_by_env("TFLITE_BEAM_CACHE_DIR", CacheBaseDir),
case mkdir_dir_p(CacheDir) of
ok ->
{ok, CacheDir};
{error, Reason} ->
{error, lists:flatten(io_lib:fwrite("Cannot create cache directory ~s: ~s", [CacheDir, Reason]))}
end.

cache_path(CacheSubdir, CacheFilename, ForceDownload) ->
case cache_basepath() of
{ok, BaseDir} ->
CacheDir = filename:join(BaseDir, CacheSubdir),
case mkdir_dir_p(CacheDir) of
ok ->
CacheFile = filename:join(CacheDir, CacheFilename),
FileExists = if
ForceDownload -> false;
true -> filelib:is_file(CacheFile)
end,
{ok, FileExists,CacheDir,CacheFile};
_ ->
{error, lists:flatten(io_lib:fwrite("Cannot create cache directory ~s", [CacheDir]))}
end;
{error, Err} ->
{error, Err}
end.

certificate_store() ->
PossibleLocations = [
%% Configured cacertfile
os:getenv("TFLITE_BEAM_CACERT"),

%% Debian/Ubuntu/Gentoo etc.
"/etc/ssl/certs/ca-certificates.crt",

%% Fedora/RHEL 6
"/etc/pki/tls/certs/ca-bundle.crt",

%5 OpenSUSE
"/etc/ssl/ca-bundle.pem",

%% OpenELEC
"/etc/pki/tls/cacert.pem",

%% CentOS/RHEL 7
"/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem",

%% Open SSL on MacOS
"/usr/local/etc/openssl/cert.pem",

%% MacOS & Alpine Linux
"/etc/ssl/cert.pem"
],
CheckExistance = lists:map(fun (F) ->
{filelib:is_file(F), F}
end, PossibleLocations),
ExistingOnes = lists:dropwhile(fun ({X, _}) -> X == false end, CheckExistance),
case length(ExistingOnes) of
Len when Len > 0 ->
{_, Cert} = hd(ExistingOnes),
Cert;
_ ->
io:fwrite("[WARNING] Cannot find CA certificate store in default locations: ~p~n", [PossibleLocations]),
io:fwrite("You can set environment variable TFLITE_BEAM_CACERT to the SSL cert on your system.~n"),
nil
end.

preferred_ciphers() ->
PreferredCiphers = [
%% Cipher suites (TLS 1.3): TLS_AES_128_GCM_SHA256:TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256
#{cipher => aes_128_gcm, key_exchange => any, mac => aead, prf => sha256},
#{cipher => aes_256_gcm, key_exchange => any, mac => aead, prf => sha384},
#{cipher => chacha20_poly1305, key_exchange => any, mac => aead, prf => sha256},
%% Cipher suites (TLS 1.2): ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:
%% ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:
%% ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384
#{cipher => aes_128_gcm, key_exchange => ecdhe_ecdsa, mac => aead, prf => sha256},
#{cipher => aes_128_gcm, key_exchange => ecdhe_rsa, mac => aead, prf => sha256},
#{cipher => aes_256_gcm, key_exchange => ecdh_ecdsa, mac => aead, prf => sha384},
#{cipher => aes_256_gcm, key_exchange => ecdh_rsa, mac => aead, prf => sha384},
#{cipher => chacha20_poly1305, key_exchange => ecdhe_ecdsa, mac => aead, prf => sha256},
#{cipher => chacha20_poly1305, key_exchange => ecdhe_rsa, mac => aead, prf => sha256},
#{cipher => aes_128_gcm, key_exchange => dhe_rsa, mac => aead, prf => sha256},
#{cipher => aes_256_gcm, key_exchange => dhe_rsa, mac => aead, prf => sha384}
],
ssl:filter_cipher_suites(PreferredCiphers, []).

protocol_versions() ->
case list_to_integer(erlang:system_info(otp_release)) of
Version when Version < 25 ->
['tlsv1.2'];
_ ->
['tlsv1.2', 'tlsv1.3']
end.

preferred_eccs() ->
%% TLS curves: X25519, prime256v1, secp384r1
PreferredECCS = [secp256r1, secp384r1],
ssl:eccs() -- ssl:eccs() -- PreferredECCS.

secure_ssl() ->
case os:getenv("TFLITE_BEAM_UNSAFE_HTTPS") of
nil -> true;
"FALSE" -> false;
"false" -> false;
"nil" -> false;
"NIL" -> false;
_ -> true
end.

https_opts(Hostname) ->
CertFile = certificate_store(),
case {secure_ssl(), is_list(CertFile)} of
{true, true} ->
[
{
ssl, [
{verify, verify_peer},
{cacertfile, CertFile},
{depth, 4},
{ciphers, preferred_ciphers()},
{versions, protocol_versions()},
{eccs, preferred_eccs()},
{reuse_sessions, true},
{server_name_indication, Hostname},
{secure_renegotiate, true},
{customize_hostname_check, [
{match_fun, public_key:pkix_verify_hostname_match_fun(https)}
]}
]
}
];
_ ->
[
{
ssl, [
{verify, verify_none},
{ciphers, preferred_ciphers()},
{versions, protocol_versions()},
{reuse_sessions, true},
{server_name_indication, Hostname},
{secure_renegotiate, true}
]
}
]
end.

set_proxy_if_exists(ProxyEnvVar, ProxyAtom) ->
case os:getenv(ProxyEnvVar) of
false ->
ok;
ProxyUri ->
case uri_string:parse(ProxyUri) of
#{host := Host, port := Port} ->
httpc:set_options([{ProxyAtom, {{Host, Port}, []}}]),
ok;
_ ->
ok
end
end.

do_download(URL, DestionationFilename) ->
{Parsed, HttpOpts} = case uri_string:parse(URL) of
#{scheme := "http"} ->
{ok, nil};
#{scheme := "https", host := Host} ->
ssl:start(),
{ok, https_opts(Host)};
_ ->
{error, "Unsupported URL scheme"}
end,
case Parsed of
ok ->
application:ensure_started(inets),
set_proxy_if_exists("HTTP_PROXY", proxy),
set_proxy_if_exists("HTTPS_PROXY", https_proxy),
set_proxy_if_exists("http_proxy", proxy),
set_proxy_if_exists("https_proxy", https_proxy),
Request = {URL, []},
case httpc:request(get, Request, HttpOpts, [{body_format, binary}]) of
{ok, {{_, 200, _}, _, Body}} ->
case file:write_file(DestionationFilename, Body) of
ok ->
{ok, DestionationFilename};
Err ->
{error, lists:flatten(io_lib:fwrite("Cannot write file ~s: ~p", [DestionationFilename, Err]))}
end;
{ok, {{_, StatusCode, _}, _, Body}} ->
{error, lists:flatten(io_lib:fwrite("Cannot download file from ~s, status code ~p: ~p", [URL, StatusCode, Body]))};
{error, Reason} ->
{error, Reason};
Err ->
{error, lists:flatten(io_lib:fwrite("Cannot download file from ~s: ~p", [URL, Err]))}
end;
error ->
{error, lists:flatten(io_lib:fwrite("Cannot parse URL ~s", [URL]))}
end.

0 comments on commit 94f24f6

Please sign in to comment.