From c87711c3d372a7200e373bd205dac5838465aaa2 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 1 Nov 2023 20:00:34 -0700 Subject: [PATCH] Script for comparison of CoCa zero-shot to open_clip implementation ghstack-source-id: d89ac55d9ec0d26e0981b7e2fa94ea85de046088 Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/507 --- ...a_ViT-L-14_en_zeroshot_classification.json | 1 + examples/coca/data.py | 1094 +++++++++++++++++ examples/coca/eval.py | 232 ++++ examples/coca/models.txt | 1 + examples/coca/webdatasets.txt | 1 + torchmultimodal/models/coca/coca_model.py | 61 + .../models/coca/multimodal_decoder.py | 4 + .../modules/encoders/vision_transformer.py | 4 + .../modules/layers/patch_embedding.py | 11 +- torchmultimodal/modules/layers/transformer.py | 57 +- 10 files changed, 1462 insertions(+), 4 deletions(-) create mode 100644 examples/coca/benchmark_imagenet1k_laion2b_s13b_b90k_coca_ViT-L-14_en_zeroshot_classification.json create mode 100644 examples/coca/data.py create mode 100644 examples/coca/eval.py create mode 100644 examples/coca/models.txt create mode 100644 examples/coca/webdatasets.txt diff --git a/examples/coca/benchmark_imagenet1k_laion2b_s13b_b90k_coca_ViT-L-14_en_zeroshot_classification.json b/examples/coca/benchmark_imagenet1k_laion2b_s13b_b90k_coca_ViT-L-14_en_zeroshot_classification.json new file mode 100644 index 000000000..3adabfbc2 --- /dev/null +++ b/examples/coca/benchmark_imagenet1k_laion2b_s13b_b90k_coca_ViT-L-14_en_zeroshot_classification.json @@ -0,0 +1 @@ +{"dataset": "wds/imagenet1k", "model": "coca_ViT-L-14", "pretrained": "laion2b_s13b_b90k", "task": "zeroshot_classification", "metrics": {"acc1": 0.7564, "acc5": 0.94286, "mean_per_class_recall": 0.75652}, "language": "en"} diff --git a/examples/coca/data.py b/examples/coca/data.py new file mode 100644 index 000000000..4eb3b0a8f --- /dev/null +++ b/examples/coca/data.py @@ -0,0 +1,1094 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# File taken from https://github.com/mlfoundations/open_clip/ + +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + 
"jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland 
White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + 
"grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + 
"bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather 
boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck 
brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + 
"snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice 
cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a 
black and white photo of the {c}.", + lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the {c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo 
of a {c}.",
+    lambda c: f"itap of a {c}.",
+    lambda c: f"graffiti of the {c}.",
+    lambda c: f"a toy {c}.",
+    lambda c: f"itap of my {c}.",
+    lambda c: f"a photo of a cool {c}.",
+    lambda c: f"a photo of a small {c}.",
+    lambda c: f"a tattoo of the {c}.",
+]
diff --git a/examples/coca/eval.py b/examples/coca/eval.py
new file mode 100644
index 000000000..45e962f3c
--- /dev/null
+++ b/examples/coca/eval.py
@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import random
+import subprocess
+from pathlib import Path
+
+import numpy as np
+import torch
+from clip_benchmark.datasets.builder import build_dataset
+from data import imagenet_classnames, openai_imagenet_template
+from torch.utils.data import DataLoader
+from torchmultimodal import _PATH_MANAGER
+from torchmultimodal.models.coca.coca_model import coca_vit_l_14_open_clip
+from torchmultimodal.transforms.clip_transform import (
+    CLIPImageTransform,
+    CLIPTextTransform,
+)
+from tqdm import tqdm
+
+
+def setup_for_distributed(is_master):
+    """
+    Disable printing on non-master processes.
+
+    Replaces ``builtins.print`` with a wrapper that prints only when
+    ``is_master`` is true or the caller passes ``force=True``.
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def init_distributed_mode(args):
+    """
+    Initialize torch.distributed from launcher environment variables.
+
+    Reads RANK / WORLD_SIZE / LOCAL_RANK, or SLURM_PROCID as a fallback. Always
+    sets ``args.distributed``; when a launcher is detected it also sets
+    ``args.rank``, ``args.world_size``, ``args.gpu`` and ``args.dist_backend``.
+    """
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        # Without this the process group is created with the --world-size
+        # default even when SLURM launched several tasks.
+        args.world_size = int(os.environ.get("SLURM_NTASKS", args.world_size))
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}): {}".format(args.rank, args.dist_url),
+        flush=True,
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def get_args_parser():
+    """
+    Return the CLI parser; built with ``add_help=False`` for use as a parent.
+    """
+    parser = argparse.ArgumentParser("CoCa eval", add_help=False)
+    parser.add_argument("--device", default=0, type=int, help="GPU id to use")
+    parser.add_argument(
+        "--world-size", default=1, type=int, help="number of distributed processes"
+    )
+    parser.add_argument(
+        "--dist-url", default="env://", help="init url for distributed training"
+    )
+    parser.add_argument("--seed", default=42, type=int, help="random seed")
+    parser.add_argument(
+        "--pretrained", default="", help="path to pretrained checkpoint"
+    )
+    parser.add_argument("--output_dir", default=".", help="path to save outputs")
+    return parser
+
+
+# Called form on purpose: bare ``@torch.no_grad`` needs a newer PyTorch.
+@torch.no_grad()
+def _zero_shot_classifier(model, device, text_transform, *args, **kwargs):
+    """
+    Build zero-shot classifier weights from the ImageNet class-name prompts.
+
+    For every class, all prompt templates are encoded, L2-normalized, averaged
+    and re-normalized; the per-class vectors are then stacked along dim 1.
+    """
+    zeroshot_weights = []
+    for classname in tqdm(imagenet_classnames):
+        texts = text_transform(
+            [template(classname) for template in openai_imagenet_template]
+        )
+        texts = texts.to(device)
+        class_embeddings = model.encode_text(texts)
+        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+        class_embedding = class_embeddings.mean(dim=0)
+        class_embedding /= class_embedding.norm()
+        zeroshot_weights.append(class_embedding)
+
+    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
+    return zeroshot_weights
+
+
+def _accuracy(output, target, topk=(1,)):
+    """
+    Return, for each k in ``topk``, how many samples have their target among
+    the k highest-scoring entries of ``output`` along dim 1.
+    """
+    pred = output.topk(max(topk), 1, True, True)[1].t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+    # .item() yields a plain float directly; float() on a 1-element ndarray is
+    # deprecated in recent NumPy releases.
+    return [correct[:k].reshape(-1).float().sum().item() for k in topk]
+
+
+@torch.no_grad()
+def run_imagenet_zero_shot(model, dataloader, device, text_transform, *args, **kwargs):
+    """
+    Compute zero-shot top-1 / top-5 accuracy of ``model`` over ``dataloader``.
+
+    Returns a dict with the keys "imagenet-zeroshot-test-top1" and
+    "imagenet-zeroshot-test-top5".
+
+    Raises:
+        ValueError: if ``dataloader`` yields no samples.
+    """
+    print("Starting ImageNet Zero-Shot Eval")
+    print("Building classifier")
+    classifier = _zero_shot_classifier(model, device, text_transform)
+    print("Classifier built")
+    top1, top5, n = 0.0, 0.0, 0.0
+    for images, target in tqdm(dataloader):
+        images = images.to(device)
+        target = target.to(device)
+
+        image_features = model.encode_image(images)
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+        logits = 100.0 * image_features @ classifier
+
+        # measure accuracy
+        acc1, acc5 = _accuracy(logits, target, topk=(1, 5))
+        top1 += acc1
+        top5 += acc5
+        n += images.size(0)
+
+    if n == 0:
+        raise ValueError("dataloader yielded no samples")
+
+    results = {}
+    results["imagenet-zeroshot-test-top1"] = top1 / n
+    results["imagenet-zeroshot-test-top5"] = top5 / n
+    return results
+
+
+def accuracy(output, target, topk=(1,)):
+    """Public alias of ``_accuracy``, kept so existing callers keep working."""
+    return _accuracy(output, target, topk=topk)
+
+
+def run_open_clip_zero_shot():
+    """
+    Run the open_clip baseline through the ``clip_benchmark`` CLI.
+
+    Returns the CLI's exit status so callers can detect a failed baseline run.
+    """
+    # The {dataset}, {model}, ... fields are placeholders for clip_benchmark
+    # to fill in, so these strings are deliberately not f-strings.
+    baseline_command = [
+        "clip_benchmark",
+        "eval",
+        "--pretrained_model",
+        "models.txt",
+        "--dataset",
+        "webdatasets.txt",
+        "--dataset_root",
+        "https://huggingface.co/datasets/clip-benchmark/wds_{dataset_cleaned}/tree/main",
+        "--output",
+        "benchmark_{dataset}_{pretrained}_{model}_{language}_{task}.json",
+    ]
+    print("Running open_clip zero-shot on imagenet")
+    # An argument list (no shell) avoids quoting issues and exposes the status
+    # that os.system() previously discarded.
+    completed = subprocess.run(baseline_command, check=False)
+    if completed.returncode != 0:
+        print(f"clip_benchmark exited with status {completed.returncode}")
+    return completed.returncode
+
+
+def run_torchmultimodal_zero_shot(device, pretrained=""):
+    """
+    Evaluate the TorchMultimodal CoCa model zero-shot on wds/imagenet1k.
+
+    Args:
+        device: device the model and the batches are moved to.
+        pretrained: optional checkpoint path. It is a parameter rather than a
+            read of the script-level ``args`` global so the function also
+            works when this module is imported.
+
+    Returns:
+        The results dict produced by ``run_imagenet_zero_shot``.
+    """
+    print("defining transform")
+    transform = CLIPImageTransform(is_train=False)
+    print("building dataset")
+    tmm_dataset = build_dataset(
+        dataset_name="wds/imagenet1k",
+        root="https://huggingface.co/datasets/clip-benchmark/wds_imagenet1k/tree/main",
+        transform=transform,
+        split="test",
+        annotation_file="",
+        download=True,
+        language="en",
+        task="zeroshot_classification",
+        custom_template_file=None,
+        custom_classname_file=None,
+        wds_cache_dir=None,
+    )
+    dataloader = DataLoader(tmm_dataset, batch_size=64)
+
+    # Build the model
+    print("building model")
+    model = coca_vit_l_14_open_clip()
+    model.to(device)
+    if pretrained:
+        print("loading checkpoint")
+        with _PATH_MANAGER.open(pretrained, "rb") as f:
+            # Load onto CPU first so a checkpoint saved on another device
+            # still loads; load_state_dict copies into the model's device.
+            weights = torch.load(f, map_location="cpu")
+        model.load_state_dict(weights)
+    model.eval()
+
+    text_transform = CLIPTextTransform()
+    tmm_out = run_imagenet_zero_shot(model, dataloader, device, text_transform)
+    return tmm_out
+
+
+def main(args):
+    """Set up distributed mode and seeds, then run both zero-shot evaluations."""
+    # Init distributed mode
+    init_distributed_mode(args)
+
+    # Under a distributed launcher every rank must use its own local GPU
+    # instead of the single --device id.
+    if args.distributed:
+        device = torch.device("cuda", args.gpu)
+    else:
+        device = torch.device(args.device)
+
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        rank = torch.distributed.get_rank()
+    else:
+        rank = 0
+
+    # fix the seed for reproducibility
+    seed = args.seed + rank
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
+    tmm_out = run_torchmultimodal_zero_shot(device, pretrained=args.pretrained)
+    print(f"TorchMultimodal zero-shot accuracy: {tmm_out}")
+    # One baseline run is enough; extra ranks would all write the same output.
+    if rank == 0:
+        run_open_clip_zero_shot()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("CoCa zero-shot eval", parents=[get_args_parser()])
+    args = parser.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    main(args)
diff --git a/examples/coca/models.txt b/examples/coca/models.txt
new file mode 100644
index 000000000..da6023e3b
--- /dev/null
+++ b/examples/coca/models.txt
@@ -0,0 +1 @@
+coca_ViT-L-14,laion2b_s13b_b90k
diff --git a/examples/coca/webdatasets.txt b/examples/coca/webdatasets.txt
new file mode 100644
index 000000000..c0d53401f
--- /dev/null
+++ b/examples/coca/webdatasets.txt
@@ -0,0 +1 @@
+wds/imagenet1k
diff --git a/torchmultimodal/models/coca/coca_model.py b/torchmultimodal/models/coca/coca_model.py
index 5eab2b10e..827625cff 100644
--- a/torchmultimodal/models/coca/coca_model.py
+++ b/torchmultimodal/models/coca/coca_model.py
@@ -65,6 
+65,30 @@ def __init__( self.vision_pooler = vision_pooler self.vision_proj = vision_proj + def _encode_image(self, images, normalize=True): + image_out = self.vision_encoder(images).last_hidden_state + image_out = self.vision_pooler(image_out) + image_first_token = image_out[:, 0] + image_out = self.vision_proj(image_first_token) + image_out = F.normalize(image_out, dim=-1) if normalize else image_out + return image_out + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text_decoder(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text( + text, normalize=normalize, embed_cls=embed_cls + ) + return text_latent + def forward( self, images: Tensor, texts: Tensor, text_padding_mask: Optional[Tensor] = None ) -> CoCaModelOutput: @@ -164,6 +188,8 @@ def coca_vit( vision_include_cls_embed: bool = False, # This is different from ViT default vision_drop_path_rate: Optional[float] = None, vision_patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None, + vision_patch_embedding_has_bias: Optional[bool] = True, + vision_transformer_ln_pre: Optional[bool] = False, # Optional text args pad_idx: Optional[int] = 0, text_embed_cls: bool = True, @@ -179,6 +205,8 @@ def coca_vit( fusion_norm_first: bool = True, fusion_final_layer_norm_eps: Optional[float] = 1e-5, multimodal_output_projection_dim: Optional[int] = None, + fusion_use_extra_mlp: Optional[bool] = False, + fusion_kv_norm: Optional[bool] = False, # Optional attention pooler args cascaded_pooler: bool = True, pooler_n_queries: int = 256, @@ -297,6 +325,8 @@ def coca_vit( 
include_cls_embed=vision_include_cls_embed, drop_path_rate=vision_drop_path_rate, patch_drop_rate=vision_patch_drop_rate, + patch_embedding_has_bias=vision_patch_embedding_has_bias, + transformer_ln_pre=vision_transformer_ln_pre, ) text_decoder = CoCaTextDecoder( @@ -330,6 +360,8 @@ def coca_vit( layer_norm_eps=fusion_layer_norm_eps, norm_first=fusion_norm_first, final_layer_norm_eps=fusion_final_layer_norm_eps, + use_extra_mlp=fusion_use_extra_mlp, + kv_norm=fusion_kv_norm, ) return CoCaModel( @@ -391,6 +423,35 @@ def coca_vit_l_14(): ) +def coca_vit_l_14_open_clip(): + return coca_vit( + vision_patch_size=14, + vision_n_layer=24, + vision_n_head=16, + vision_dim_feedforward=4096, + vision_include_cls_embed=True, + vocab_size=49408, + num_text_positions=77, + text_hidden_dim=768, + text_n_layer=12, + text_n_head=12, + text_dim_feedforward=3072, + text_output_dim=768, + fusion_n_layer=12, + fusion_n_head=12, + fusion_dim_feedforward=3072, + multimodal_output_projection_dim=49408, + pooler_input_embed_dim=1024, + pooler_output_embed_dim=768, + pooler_n_head=8, + cascaded_pooler=False, + vision_patch_embedding_has_bias=False, + vision_transformer_ln_pre=True, + fusion_use_extra_mlp=True, + fusion_kv_norm=True, + ) + + class CoCaForPretraining(nn.Module): """ CoCa pretraining model class. 
diff --git a/torchmultimodal/models/coca/multimodal_decoder.py b/torchmultimodal/models/coca/multimodal_decoder.py index 6520b44b5..118d8e7c5 100644 --- a/torchmultimodal/models/coca/multimodal_decoder.py +++ b/torchmultimodal/models/coca/multimodal_decoder.py @@ -53,6 +53,8 @@ def __init__( norm_first: bool = True, final_layer_norm_eps: Optional[float] = 1e-5, visual_embedding_dim: Optional[int] = None, + use_extra_mlp: Optional[bool] = False, + kv_norm: Optional[bool] = False, ): super().__init__() self.transformer_decoder = TransformerDecoder( @@ -67,6 +69,8 @@ def __init__( use_cross_attention=True, final_layer_norm_eps=final_layer_norm_eps, dim_kv=visual_embedding_dim, + use_extra_mlp=use_extra_mlp, + kv_norm=kv_norm, ) if output_dim is not None: self.output_projection = nn.Linear( diff --git a/torchmultimodal/modules/encoders/vision_transformer.py b/torchmultimodal/modules/encoders/vision_transformer.py index 4e170ee6e..475898aed 100644 --- a/torchmultimodal/modules/encoders/vision_transformer.py +++ b/torchmultimodal/modules/encoders/vision_transformer.py @@ -146,7 +146,9 @@ def vision_transformer( final_layer_norm_eps: Optional[float] = 1e-6, norm_first: bool = True, include_cls_embed: bool = True, + patch_embedding_has_bias: bool = True, drop_path_rate: Optional[float] = None, + transformer_ln_pre: bool = False, patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None, pooler: Optional[nn.Module] = None, ckpt_path: str = None, @@ -184,6 +186,7 @@ def vision_transformer( patch_drop_rate=patch_drop_rate, num_channels=num_channels, include_cls_embed=include_cls_embed, + conv_proj_has_bias=patch_embedding_has_bias, ) transformer_encoder = TransformerEncoder( n_layer=n_layer, @@ -196,6 +199,7 @@ def vision_transformer( norm_first=norm_first, final_layer_norm_eps=final_layer_norm_eps, drop_path_rate=drop_path_rate, + ln_pre=transformer_ln_pre, ) vit = VisionTransformer( embeddings=image_embedding, encoder=transformer_encoder, pooler=pooler diff --git 
a/torchmultimodal/modules/layers/patch_embedding.py b/torchmultimodal/modules/layers/patch_embedding.py index 4ff48e722..8abfe2d41 100644 --- a/torchmultimodal/modules/layers/patch_embedding.py +++ b/torchmultimodal/modules/layers/patch_embedding.py @@ -49,6 +49,7 @@ def __init__( use_image_masking: bool = False, patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None, include_cls_embed: bool = True, + conv_proj_has_bias: bool = True, ) -> None: super().__init__() if isinstance(image_size, int): @@ -65,7 +66,11 @@ def __init__( self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) num_patches = num_patches + 1 self.conv_projection = nn.Conv2d( - num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + num_channels, + hidden_size, + kernel_size=patch_size, + stride=patch_size, + bias=conv_proj_has_bias, ) self._init_conv_weights() @@ -90,8 +95,8 @@ def _init_conv_weights(self) -> None: * self.conv_projection.kernel_size[1] ) nn.init.trunc_normal_(self.conv_projection.weight, std=math.sqrt(1 / fan_in)) - assert self.conv_projection.bias is not None - nn.init.zeros_(self.conv_projection.bias) + if self.conv_projection.bias is not None: + nn.init.zeros_(self.conv_projection.bias) def forward( self, diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 47f42d57c..76660ca16 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -186,8 +186,14 @@ def __init__( norm_first: bool = False, final_layer_norm_eps: Optional[float] = None, drop_path_rate: Optional[float] = None, + ln_pre: bool = False, ): super().__init__() + self.ln_pre = ln_pre + if self.ln_pre: + self.pre_layer_norm = nn.LayerNorm(d_model) + else: + self.pre_layer_norm = nn.Identity() if drop_path_rate is not None: drop_rate = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)] else: @@ -232,6 +238,9 @@ def forward( The last entry in the list is the 
output from last encoder block before final ln has been applied. """ + if self.ln_pre: + hidden_states = self.pre_layer_norm(hidden_states) + all_hidden_states = [] for layer_module in self.layer: @@ -291,6 +300,8 @@ def __init__( norm_first: bool = False, use_cross_attention: bool = True, dim_kv: Optional[int] = None, + use_extra_mlp: Optional[bool] = False, + kv_norm: Optional[bool] = False, ) -> None: super().__init__() if dim_kv is not None: @@ -319,6 +330,25 @@ def __init__( ) self.cross_attention_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps) self.cross_attention_dropout = nn.Dropout(dropout) + self.use_extra_mlp = use_extra_mlp + self.kv_norm = kv_norm + if self.use_extra_mlp: + self.extra_mlp_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps) + self.extra_mlp = MLP( + d_model, + d_model, + dim_feedforward, + dropout=dropout, + activation=activation, + ) + self.extra_mlp_dropout = nn.Dropout(dropout) + + if use_cross_attention: + self.ln_1_kv = ( + Fp32LayerNorm(dim_kv, eps=layer_norm_eps) + if self.kv_norm + else nn.Identity() + ) # Feedforward self.feedforward = MLP( @@ -392,6 +422,11 @@ def _feedforward_block(self, hidden_states: Tensor) -> Tensor: h = self.feedforward_dropout(h) return h + def _extra_mlp_block(self, hidden_states: Tensor) -> Tensor: + h = self.extra_mlp(hidden_states) + h = self.extra_mlp_dropout(h) + return h + def _forward_prenorm( self, hidden_states: Tensor, @@ -410,7 +445,13 @@ def _forward_prenorm( past_key_value=past_key_value, use_cache=use_cache, ) - self_attn_output = attn_output + hidden_states + + if self.use_extra_mlp: + # Feedforward + ff_input = self.extra_mlp_layernorm(attn_output) + self_attn_output = attn_output + self._extra_mlp_block(ff_input) + else: + self_attn_output = attn_output + hidden_states # Optional cross-attention if self.use_cross_attention: @@ -421,6 +462,9 @@ def _forward_prenorm( self, "cross_attention_layernorm" ), "Cross-attention layernorm not initialized" cross_attn_input = 
self.cross_attention_layernorm(self_attn_output)
+            # KV norm only used as pre-norm
+            if self.kv_norm:
+                encoder_hidden_states = self.ln_1_kv(encoder_hidden_states)
             cross_attn_output = self._cross_attention_block(
                 cross_attn_input,
                 encoder_hidden_states,
@@ -456,6 +500,13 @@ def _forward_postnorm(
         attn_residual = attn_output + hidden_states
         self_attn_output = self.attention_layernorm(attn_residual)
 
+        if self.use_extra_mlp:
+            # Post-norm extra MLP: residual around the normalized attention
+            # output, then its own layer norm. When disabled, the normalized
+            # attention output above is passed through unchanged.
+            resid_output = self_attn_output + self._extra_mlp_block(self_attn_output)
+            self_attn_output = self.extra_mlp_layernorm(resid_output)
+
         # Optional cross-attention
         if self.use_cross_attention:
             if encoder_hidden_states is None:
@@ -569,6 +620,8 @@ def __init__(
         use_cross_attention: bool = True,
         dim_kv: Optional[int] = None,
         final_layer_norm_eps: Optional[float] = None,
+        use_extra_mlp: Optional[bool] = False,
+        kv_norm: Optional[bool] = False,
     ):
         super().__init__()
         self.layer = nn.ModuleList(
@@ -583,6 +636,8 @@ def __init__(
                     norm_first,
                     use_cross_attention,
                     dim_kv,
+                    use_extra_mlp,
+                    kv_norm,
                 )
                 for i in range(n_layer)
             ]