
Commit

Merge pull request #1148 from JohnSnowLabs/feature/support-for-loading-datasets-from-dlt-within-databricks

Feature/support for loading datasets from dlt and spark dataframe within databricks
chakravarthik27 authored Dec 2, 2024
2 parents 85b3fc9 + 3bd5ca0 commit 0578ce6
Showing 2 changed files with 201 additions and 3 deletions.
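
For context, a minimal usage sketch of the new loaders through langtest's Harness, built from the config keys the diff below reads (data_source, source, spark_session, format); the model settings, file path, and column names are illustrative assumptions, not part of this commit:

    from pyspark.sql import SparkSession
    from langtest import Harness

    spark = SparkSession.builder.appName("langtest_demo").getOrCreate()

    # hypothetical CSV with "text" and "label" columns
    df = spark.read.csv("dbfs:/data/reviews.csv", header=True, inferSchema=True)

    harness = Harness(
        task="text-classification",
        model={"model": "distilbert-imdb", "hub": "huggingface"},  # assumed model
        data={
            "data_source": df,       # a Spark DataFrame or a file path
            "source": "spark",       # routes loading to the new SparkDataset
            "spark_session": spark,  # reused instead of creating a fresh session
        },
    )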
202 changes: 200 additions & 2 deletions langtest/datahandler/datasource.py
@@ -141,6 +141,10 @@ def __init_subclass__(cls, **kwargs):
import pandas as pd

dataset_cls = cls.__name__.replace("Dataset", "").lower()

if dataset_cls == "deltalivetables":
dataset_cls = "delta_live_tables"

if dataset_cls == "pandas":
extensions = [
i.replace("read_", "")
@@ -192,6 +196,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
raise ValueError(Errors.E025())
self._custom_label = file_path.copy()
self._file_path = file_path.get("data_source")
self.file_ext = file_path.get("source", None)
self._size = None

self.datasets_with_jsonl_extension = []
@@ -209,7 +214,7 @@ def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) ->
if isinstance(self._file_path, str):
_, self.file_ext = os.path.splitext(self._file_path)

-if len(self.file_ext) > 0:
+if len(self.file_ext) > 0 and "source" not in file_path:
self.file_ext = self.file_ext.replace(".", "")
elif "source" in file_path:
self.file_ext = file_path["source"]
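
With this change an explicit "source" key takes precedence over the filename extension, so a CSV path can still be routed to the Spark loader; for example (path hypothetical):

    # file_ext resolves to "spark" even though the path ends in ".csv"
    file_path = {"data_source": "dbfs:/data/table.csv", "source": "spark"}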
@@ -255,7 +260,7 @@ def load(self) -> List[Sample]:
list[Sample]: Loaded text data.
"""

if self.file_ext in ("csv", "huggingface"):
if self.file_ext in ("csv", "huggingface", "spark"):
self.init_cls = self.data_sources[self.file_ext.replace(".", "")](
self._custom_label, task=self.task, **self.kwargs
)
@@ -1890,3 +1895,196 @@ def renamed_extensions(self, inverted: bool = False) -> Dict[str, str]:
"hdf5": "hdf",
}
return ext_map


class SparkDataset(BaseDataset):
"""Class to handle Spark datasets. Subclass of BaseDataset."""

supported_tasks = [
"ner",
"text-classification",
"question-answering",
"summarization",
"toxicity",
"translation",
"security",
"clinical",
"disinformation",
"sensitivity",
"wino-bias",
"legal",
]

def __init__(self, file_path: Union[str, dict], task: TaskManager, **kwargs) -> None:
"""
Initializes a SparkDataset object.
Args:
file_path (str):
The path to the data file.
task (str):
Task to be evaluated on.
**kwargs:
"""
from pyspark.sql import SparkSession
import string
import random

super().__init__()
self._file_path = file_path
self.task = task

self.spark_session: SparkSession = None
if isinstance(file_path, dict):
    self.spark_session = file_path.get("spark_session", None)
self.format = kwargs.get("format", "csv")
self.kwargs = kwargs

# create a default session when none was supplied
if self.spark_session is None:
    random_str = "langtest_" + "".join(
        random.choices(string.ascii_lowercase, k=5)
    )
    self.spark_session = SparkSession.builder.appName(random_str).getOrCreate()

def load_raw_data(self) -> List[Dict]:
"""
Load data from a file into raw lists of strings
Returns:
List[Dict]:
parsed file into list of dicts
"""
df = self.spark_session.read.csv(self._file_path, header=True, inferSchema=True)
data = df.collect()
return data

def load_data(self) -> List[Sample]:
"""
Load data from any file and preprocess it based on the specified task.
Returns:
List[Sample]: A list of preprocessed data samples.
"""
from pyspark.sql import DataFrame

if isinstance(self._file_path, dict) and isinstance(
    self._file_path.get("data_source", None), DataFrame
):
df = self._file_path.get("data_source", [])
column_names = self._file_path.get("column_names", {})

elif isinstance(self._file_path, dict):
self.default_params = self._file_path
self.format = self._file_path.get("format", "csv")
self._file_path = self._file_path.get("data_source", self._file_path)

if hasattr(self.spark_session.read, self.format):
    reader = getattr(self.spark_session.read, self.format)
    # header/inferSchema are csv-specific options; other readers
    # (json, parquet, ...) take their own keyword options
    if self.format == "csv":
        df: DataFrame = reader(self._file_path, header=True, inferSchema=True)
    else:
        df: DataFrame = reader(self._file_path)
else:
    raise ValueError(
        Errors.E027(format=self.format)
        + f" for {self._file_path} is not supported."
    )

column_names = self.default_params

# drop the non-column keys so only column mappings reach create_sample
if isinstance(column_names, dict):
    column_names.pop("data_source", None)
    column_names.pop("source", None)
    column_names.pop("spark_session", None)
    column_names.pop("format", None)  # not a column mapping either
else:
    column_names = dict()

# generate the sample data
data = []

for idx, row_data in enumerate(df.toPandas().to_dict(orient="records")):
try:
sample = self.task.create_sample(
row_data,
**column_names,
)
data.append(sample)

except Exception as e:
logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
continue

return data

def export_data(self, data: List[Sample], output_path: str):
"""Exports the data to the corresponding format and saves it to 'output_path'."""
raise NotImplementedError()
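
A sketch of driving SparkDataset directly, assuming TaskManager can be built from a task name and that the CSV has text/label columns (both assumptions for illustration):

    from pyspark.sql import SparkSession

    from langtest.datahandler.datasource import SparkDataset
    from langtest.tasks.task import TaskManager

    spark = SparkSession.builder.appName("spark_dataset_demo").getOrCreate()

    dataset = SparkDataset(
        file_path={
            "data_source": "dbfs:/data/reviews.csv",  # a path or a DataFrame
            "source": "spark",
            "format": "csv",  # any spark.read.<format> method name
            "spark_session": spark,
        },
        task=TaskManager("text-classification"),
    )
    samples = dataset.load_data()  # -> List[Sample]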


class DeltaLiveTablesDataset(BaseDataset):
"""A class to handle datasets from the Delta Live Tables(DLT)."""

supported_tasks = [
"ner",
"text-classification",
"question-answering",
"summarization",
"toxicity",
"translation",
"security",
]

def __init__(self, file_path, task: TaskManager, **kwargs) -> None:
"""
Initializes a DeltaLiveTablesDataset object.
Args:
file_path (DataFrame):
A Spark DataFrame representing the DLT table.
task (TaskManager):
Task to be evaluated on.
**kwargs: Additional keyword arguments.
"""
super().__init__()
self._file_path = file_path
self.task = task
self.kwargs = kwargs

def load_raw_data(self) -> List[Dict]:
"""
Load data from a file into raw lists of strings
Returns:
List[Dict]:
parsed file into list of dicts
"""
raise NotImplementedError()

def load_data(self) -> List[Sample]:
"""
Load data from a DLT table (a Spark DataFrame) and preprocess it based on the specified task.
Returns:
List[Sample]: A list of preprocessed data samples.
"""
from pyspark.sql import DataFrame

if not isinstance(self._file_path, DataFrame):
raise ValueError(
"file_path should be a Spark DataFrame representing the DLT table"
)

df: DataFrame = self._file_path
data = []

for idx, row_data in enumerate(df.toPandas().to_dict(orient="records")):
try:
sample = self.task.create_sample(row_data)
data.append(sample)
except Exception as e:
logging.warning(Warnings.W005(idx=idx, row_data=row_data, e=e))
continue

self.dataset_size = len(data)
return data

def export_data(self, data: List[Sample], output_path: str):
"""Exports the data to the corresponding format and saves it to 'output_path'."""
raise NotImplementedError()
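
Because DLT tables surface in Databricks as Spark DataFrames, the loader stays a thin wrapper; a sketch, with the table name hypothetical:

    # inside a Databricks notebook, where `spark` is provided
    from langtest.datahandler.datasource import DeltaLiveTablesDataset
    from langtest.tasks.task import TaskManager

    df = spark.read.table("live.cleaned_reviews")  # hypothetical DLT table

    dataset = DeltaLiveTablesDataset(
        file_path=df,  # must be a Spark DataFrame, per load_data above
        task=TaskManager("text-classification"),
    )
    samples = dataset.load_data()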
2 changes: 1 addition & 1 deletion langtest/tasks/task.py
@@ -300,7 +300,7 @@ def create_sample(
elif isinstance(labels, list) or isinstance(labels, str):
labels = ast.literal_eval(labels)
if not isinstance(labels, list):
-labels = [labels]
+labels = [str(labels)]
labels = [
samples.SequenceLabel(label=label, score=1.0)
if isinstance(label, str)
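
The str() coercion matters because ast.literal_eval can return non-string scalars, which would then skip the SequenceLabel wrapping above; a sketch of the case being fixed:

    import ast

    labels = ast.literal_eval("1")  # -> int 1, not "1"
    # before: [1] left a bare int that isinstance(label, str) rejects
    # after: [str(labels)] -> ["1"] keeps label handling uniform
    labels = [str(labels)]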
