-
Notifications
You must be signed in to change notification settings - Fork 0
/
driver.py
32 lines (23 loc) · 931 Bytes
/
driver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from abc import ABC, abstractmethod
from typing import Union
from pyspark import RDD
from pyspark.sql import SparkSession, DataFrame
class SparkDriver(ABC):
    """Abstract base for drivers that own a Spark session and delegate I/O.

    Subclasses provide the concrete read and write logic by implementing
    ``reader_fn`` and ``writer_fn``. This base class stores the application
    name, creates and stops the session, and forwards ``get_data`` / ``save``
    calls to those hooks.
    """

    # Name passed to the session builder in start().
    app_name: str
    # Assigned by start(); the attribute does not exist on the instance
    # until start() has been called.
    session: SparkSession

    def __init__(self, app_name: str):
        """Remember the application name; no session is created yet."""
        self.app_name = app_name

    @abstractmethod
    def reader_fn(self, *args, **kwargs) -> Union[RDD, DataFrame]:
        """Load data and return it.

        Called by ``get_data`` with ``(file_name, protocol, *args, **kwargs)``.
        """

    @abstractmethod
    def writer_fn(self, *args, **kwargs) -> None:
        """Persist data.

        Called by ``save`` with ``(df, file_name, *args, **kwargs)``.
        """

    def start(self) -> None:
        """Create (or reuse) a session named ``app_name`` and keep it on ``self.session``."""
        self.session = SparkSession.builder.appName(self.app_name).getOrCreate()

    def stop(self) -> None:
        """Stop the session held by this driver.

        Does nothing when ``start()`` was never called, so it is safe to use
        in cleanup paths (e.g. ``finally`` blocks) regardless of how far
        start-up got.
        """
        session = getattr(self, "session", None)
        if session is not None:
            session.stop()

    def get_data(self, file_name: str, protocol: str, *args, **kwargs) -> Union[RDD, DataFrame]:
        """Read data through the subclass's ``reader_fn``.

        Args:
            file_name: Forwarded as the first positional argument.
            protocol: Forwarded as the second positional argument; its
                meaning is defined by the subclass.
            *args: Extra positional arguments forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever ``reader_fn`` returns.
        """
        return self.reader_fn(file_name, protocol, *args, **kwargs)

    def save(self, df: DataFrame, file_name: str, *args, **kwargs) -> None:
        """Write data through the subclass's ``writer_fn``.

        Args:
            df: Data to persist; forwarded as the first positional argument.
            file_name: Forwarded as the second positional argument.
            *args: Extra positional arguments forwarded unchanged.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever ``writer_fn`` returns (``None`` by contract).
        """
        # Unpack the extras so the writer receives them the same way
        # reader_fn does in get_data; passing the raw tuple/dict would hand
        # the writer two unexpected positional arguments instead.
        return self.writer_fn(df, file_name, *args, **kwargs)