scandir and md5 for adlsgen2setup.py #2113

Open · wants to merge 10 commits into main
1 change: 1 addition & 0 deletions app/backend/requirements.in
@@ -30,3 +30,4 @@ types-beautifulsoup4
msgraph-sdk==1.1.0
openai-messages-token-helper
python-dotenv
datetime
3 changes: 3 additions & 0 deletions app/backend/requirements.txt
@@ -434,5 +434,8 @@ yarl==1.9.4
zipp==3.20.0
    # via importlib-metadata

# used for adlsgen2setup.py
datetime==4.3.0
    # via -r requirements.in
# The following packages are considered to be unsafe in a requirements file:
# setuptools
4 changes: 3 additions & 1 deletion docs/login_and_acl.md
@@ -254,7 +254,9 @@ The script performs the following steps:
- Creates example [groups](https://learn.microsoft.com/entra/fundamentals/how-to-manage-groups) listed in the [sampleacls.json](/scripts/sampleacls.json) file.
- Creates a filesystem / container `gptkbcontainer` in the storage account.
- Creates the directories listed in the [sampleacls.json](/scripts/sampleacls.json) file.
- Uploads the sample PDFs referenced in the [sampleacls.json](/scripts/sampleacls.json) file into the appropriate directories.
- Scans the directories recursively for files when you pass the `--scandirs` option (off by default), skipping any directory whose entry in the [sampleacls.json](/scripts/sampleacls.json) file sets `"scandir": false` (the per-directory flag defaults to true); see the example below this list.
- Calculates an MD5 checksum for each file and compares it with the checksum stored in the blob's metadata by a previous upload. If the checksums match, the upload is skipped; otherwise the file is uploaded and the new MD5 value is stored in the blob's metadata.
- Uploads the sample PDFs referenced in the [sampleacls.json](/scripts/sampleacls.json) file, plus any files found by the scan in directories with `scandir` set to true, into the appropriate directories.
- [Recursively sets Access Control Lists (ACLs)](https://learn.microsoft.com/azure/storage/blobs/data-lake-storage-acl-cli) using the information from the [sampleacls.json](/scripts/sampleacls.json) file.
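
For reference, this is roughly what a per-directory `scandir` entry looks like in the access control JSON (a minimal sketch based on the sample file in this repository; the directory and group names are illustrative):

```json
{
    "directories": {
        "GPT4V_Examples": {
            "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"],
            "scandir": true
        },
        "employeeinfo": {
            "groups": ["GPTKB_HRTest"],
            "scandir": false
        }
    }
}
```

Directories with `"scandir": true` (or with no `scandir` key at all) are walked recursively when the script is run with `--scandirs`; directories with `"scandir": false` only receive the files explicitly listed under `files`.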

In order to use the sample access control, you need to join these groups in your Microsoft Entra tenant.
140 changes: 123 additions & 17 deletions scripts/adlsgen2setup.py
@@ -1,8 +1,10 @@
import argparse
import asyncio
from datetime import datetime
import json
import logging
import os
import hashlib
from typing import Any, Optional

import aiohttp
@@ -16,7 +18,9 @@
from load_azd_env import load_azd_env

logger = logging.getLogger("scripts")

# Set the logging level for the azure package to DEBUG
logging.getLogger("azure").setLevel(logging.DEBUG)
logging.getLogger('azure.core.pipeline.policies.http_logging_policy').setLevel(logging.DEBUG)

class AdlsGen2Setup:
    """
@@ -56,7 +60,7 @@ def __init__(
        self.data_access_control_format = data_access_control_format
        self.graph_headers: Optional[dict[str, str]] = None

    async def run(self):
    async def run(self, scandirs: bool = False):
        async with self.create_service_client() as service_client:
            logger.info(f"Ensuring {self.filesystem_name} exists...")
            async with service_client.get_file_system_client(self.filesystem_name) as filesystem_client:
@@ -80,15 +84,17 @@ async def run(self):
                        )
                        directories[directory] = directory_client

                    if scandirs:
                        logger.info("Uploading scanned files...")
                        await self.scan_and_upload_directories(directories, filesystem_client)

                    logger.info("Uploading files...")
                    for file, file_info in self.data_access_control_format["files"].items():
                        directory = file_info["directory"]
                        if directory not in directories:
                            logger.error(f"File {file} has unknown directory {directory}, exiting...")
                            return
                        await self.upload_file(
                            directory_client=directories[directory], file_path=os.path.join(self.data_directory, file)
                        )
                        await self.upload_file(directory_client=directories[directory], file_path=os.path.join(self.data_directory, file))

                    logger.info("Setting access control...")
                    for directory, access_control in self.data_access_control_format["directories"].items():
@@ -100,8 +106,7 @@
f"Directory {directory} has unknown group {group_name} in access control list, exiting"
)
return
await directory_client.update_access_control_recursive(
acl=f"group:{groups[group_name]}:r-x"
await directory_client.update_access_control_recursive(acl=f"group:{groups[group_name]}:r-x"
)
if "oids" in access_control:
for oid in access_control["oids"]:
Expand All @@ -110,15 +115,114 @@ async def run(self):
for directory_client in directories.values():
await directory_client.close()

    async def walk_files(self, src_filepath="."):
        filepath_list = []

        # Use os.walk() to traverse the directory tree and record the path of every file found
        for root, dirs, files in os.walk(src_filepath):
            for file in files:
                # If the root is '.', resolve it to the current working directory via os.getcwd()
                # so the recorded path is absolute; otherwise use the root reported by os.walk()
                if root == ".":
                    filepath = os.path.join(os.getcwd(), file)
                else:
                    filepath = os.path.join(root, file)

                # Append the filepath only if it is not already in the list
                if filepath not in filepath_list:
                    filepath_list.append(filepath)

        return filepath_list

    async def scan_and_upload_directories(self, directories: dict[str, DataLakeDirectoryClient], filesystem_client):
        logger.info("Scanning and uploading files from directories recursively...")

        for directory, directory_client in directories.items():
            directory_path = os.path.join(self.data_directory, directory)
            if directory == "/":
                continue

            # Skip the directory if 'scandir' is explicitly set to False in the access control file
            if not self.data_access_control_format["directories"][directory].get("scandir", True):
                logger.info(f"Skipping directory {directory} as 'scandir' is set to False")
                continue

            # Check that the directory exists before walking it
            if not os.path.exists(directory_path):
                logger.warning(f"Directory does not exist: {directory_path}")
                continue

            # Collect all file paths using the walk_files helper
            file_paths = await self.walk_files(directory_path)

            # Upload each file that was found
            count = 0
            num = len(file_paths)
            for file_path in file_paths:
                await self.upload_file(directory_client, file_path, directory)
                count += 1
                logger.info(f"Uploaded [{count}/{num}] {directory}/{file_path}")

    def create_service_client(self):
        return DataLakeServiceClient(
            account_url=f"https://{self.storage_account_name}.dfs.core.windows.net", credential=self.credentials
        )

    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str):
        with open(file=file_path, mode="rb") as f:
            file_client = directory_client.get_file_client(file=os.path.basename(file_path))
            await file_client.upload_data(f, overwrite=True)

    async def calc_md5(self, path: str) -> str:
        # Compute the MD5 checksum of a local file, reading it in 4 KB chunks
        hash_md5 = hashlib.md5()
        with open(path, "rb") as file:
            for chunk in iter(lambda: file.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    async def get_blob_md5(self, directory_client: DataLakeDirectoryClient, filename: str) -> Optional[str]:
        """
        Retrieves the MD5 checksum from the metadata of the specified blob.
        """
        file_client = directory_client.get_file_client(filename)
        try:
            properties = await file_client.get_file_properties()
            return properties.metadata.get("md5")
        except Exception as e:
            logger.error(f"Error getting blob properties for {filename}: {e}")
            return None

    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str, category: str = ""):
        # Calculate the MD5 hash of the local file once
        md5_hash = await self.calc_md5(file_path)

        # Get the filename
        filename = os.path.basename(file_path)

        # Get the MD5 checksum stored in the blob metadata, if the blob already exists
        blob_md5 = await self.get_blob_md5(directory_client, filename)

        # Upload the file only if it does not exist yet or if the checksum differs
        if blob_md5 is None or md5_hash != blob_md5:
            with open(file_path, "rb") as f:
                file_client = directory_client.get_file_client(filename)
                tmtime = os.path.getmtime(file_path)
                last_modified = datetime.fromtimestamp(tmtime).isoformat()
                title = os.path.splitext(filename)[0]
                metadata = {
                    "md5": md5_hash,
                    "category": category,
                    "updated": last_modified,
                    "title": title,
                }
                await file_client.upload_data(f, overwrite=True)
                await file_client.set_metadata(metadata)
                logger.info(f"Uploaded and updated metadata for {filename}")
        else:
            logger.info(f"No upload needed for {filename}, checksums match")

    async def create_or_get_group(self, group_name: str):
        group_id = None
@@ -144,6 +248,7 @@ async def create_or_get_group(self, group_name: str):
                    # If Unified does not work for you, then you may need the following settings instead:
                    # "mailEnabled": False,
                    # "mailNickname": group_name,

                }
                async with session.post("https://graph.microsoft.com/v1.0/groups", json=group) as response:
                    content = await response.json()
@@ -165,19 +270,19 @@ async def main(args: Any):
            data_access_control_format = json.load(f)
        command = AdlsGen2Setup(
            data_directory=args.data_directory,
            storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
            filesystem_name="gptkbcontainer",
            storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"],
            filesystem_name=os.environ["AZURE_ADLS_GEN2_FILESYSTEM"],
            security_enabled_groups=args.create_security_enabled_groups,
            credentials=credentials,
            data_access_control_format=data_access_control_format,
        )
        await command.run()
        await command.run(args.scandirs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Upload sample data to a Data Lake Storage Gen2 account and associate sample access control lists with it using sample groups",
        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control ./scripts/sampleacls.json --create-security-enabled-groups <true|false>",
        description="Upload data to a Data Lake Storage Gen2 account and associate access control lists with it using sample groups",
        epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json --create-security-enabled-groups <true|false> --scandirs",
    )
    parser.add_argument("data_directory", help="Data directory that contains sample PDFs")
    parser.add_argument(
@@ -190,9 +295,10 @@ async def main(args: Any):
"--data-access-control", required=True, help="JSON file describing access control for the sample data"
)
parser.add_argument("--verbose", "-v", required=False, action="store_true", help="Verbose output")
parser.add_argument("--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively")
args = parser.parse_args()
if args.verbose:
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.INFO)

asyncio.run(main(args))
10 changes: 8 additions & 2 deletions scripts/sampleacls.json
@@ -21,10 +21,16 @@
    },
    "directories": {
        "employeeinfo": {
            "groups": ["GPTKB_HRTest"]
            "groups": ["GPTKB_HRTest"],
            "scandir": false
        },
        "benefitinfo": {
            "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"]
            "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"],
            "scandir": false
        },
        "GPT4V_Examples": {
            "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"],
            "scandir": true
        },
        "/": {
            "groups": ["GPTKB_AdminTest"]