Comprehensive: Spark -> pandas loop -> Spark

A busy, realistic pipeline: filter in Spark, engineer features in pandas (in-place writes, two group-bys merged back, and a per-region loop that subsets, writes, and concatenates), then return to Spark to join reference data. It shows how a loop over a category appears in the tracked lineage: one branch per value, all fanning into the concat.

Open the full report ↗

Example code

Source: examples/example_cross_engine_pipeline.py

"""Comprehensive cross-engine pipeline: PySpark -> pandas (with a per-category loop) -> PySpark.

A deliberately busy, realistic-shaped pipeline that exercises most of conformare's pandas
and Spark tracking at once:

1.  Start in **PySpark**: filter on a date and on age.
2.  Pull into **pandas** (`toPandas`).
3.  A run of in-place `df["x"] = ...` writes (these collapse into one chain).
4.  **Two separate group-bys** aggregated and merged back onto the frame.
5.  A **loop over a categorical** (region): for each value, take the matching subset, apply
    in-place writes to it, collect the pieces, then `pd.concat` them back together. A loop
    like this is not great practice -- it shows up as one filter+writes branch per category
    fanning into the concat -- but it is common, and this is what it looks like tracked.
6.  Back to **PySpark** via `spark.createDataFrame(...)`.
7.  **Join** onto another Spark dataset.

Row counts are profiled at every step and **distribution histograms** at the milestones
(entry to pandas, each merge, the loop's concat, and the final join), so the report's Node
profiles show real distributions. In the collapsed Node-profiles view each segment shows
the latest histogram it produced, annotated with where it was measured.

To see the per-category branches roll up, open the diagram and turn on "Compress chained
operations"; turn on "Expand operation details" to read every folded write.

Run:  python examples/example_cross_engine_pipeline.py
Then open output/cross_engine_pipeline_report.html.
"""

import datetime as dt
import os
import warnings

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

import conformare as cf

# Region labels shared by the synthetic data: customers are assigned to these round-robin
# and the region reference table has one row per entry.
REGIONS = ["England", "Scotland", "Wales", "Northern Ireland"]


def _customers_pdf(n: int = 320) -> pd.DataFrame:
    """Build a deterministic synthetic customer table with ``n`` rows.

    Every value is derived arithmetically from the row index, so repeated calls with the
    same ``n`` produce the same frame.

    Columns:
        customer_id: the row index, 0..n-1.
        region: an entry of ``REGIONS``, assigned round-robin.
        age: 18..72.
        signup_date: 2022-01-01 plus 0..899 days (roughly 2.5 years).
        tenure: 1..12 (months).
        spend: 20.0..299.0, stored as float.
    """
    first_signup = dt.date(2022, 1, 1)
    region_count = len(REGIONS)
    column_names = ["customer_id", "region", "age", "signup_date", "tenure", "spend"]
    records = [
        (
            idx,
            REGIONS[idx % region_count],
            18 + (idx * 7) % 55,
            first_signup + dt.timedelta(days=(idx * 11) % 900),
            1 + (idx % 12),
            float(20 + (idx * 17) % 280),
        )
        for idx in range(n)
    ]
    return pd.DataFrame(records, columns=column_names)


def _regions_pdf() -> pd.DataFrame:
    """Return the region reference table: each entry of ``REGIONS`` with its manager."""
    managers = ["Alice", "Bob", "Carys", "Dáire"]
    reference = {"region": REGIONS, "region_manager": managers}
    return pd.DataFrame(reference)


def _reachable_nodes(model):
    """Return the set of node ids reachable from the root nodes of ``model``.

    ``model`` is read as a mapping with ``"nodes"`` (dicts carrying an ``"id"``) and
    ``"edges"`` (dicts carrying ``"source"`` and ``"target"``). A root is a node that is
    never the target of an edge.
    """
    successors = {}
    targets = set()
    for edge in model["edges"]:  # one pass builds both adjacency and the set of targets
        successors.setdefault(edge["source"], []).append(edge["target"])
        targets.add(edge["target"])

    seen = set()
    stack = [node["id"] for node in model["nodes"] if node["id"] not in targets]
    while stack:  # iterative depth-first walk from every root
        node_id = stack.pop()
        if node_id not in seen:
            seen.add(node_id)
            stack += successors.get(node_id, [])
    return seen


def main(out=None):
    """Run the cross-engine example pipeline and write the HTML report.

    Args:
        out: Path of the HTML report to write. When falsy, defaults to
            ``output/cross_engine_pipeline_report.html`` under the parent of the directory
            containing this file.

    Returns:
        ``out``, the report path that was passed to ``cf.to_html``.
    """
    out = out or os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "output",
        "cross_engine_pipeline_report.html",
    )
    # A bare filename has an empty dirname and os.makedirs("") raises, so only create the
    # directory when there is one to create.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    spark = (
        SparkSession.builder.master("local[1]")
        .appName("conformare-cross-engine")
        .config("spark.ui.enabled", "false")
        .config("spark.sql.shuffle.partitions", "4")
        .getOrCreate()
    )
    try:
        spark.sparkContext.setLogLevel("ERROR")

        cf.trackSpark()
        cf.trackPandas()
        # This example is about these opt-ins; silence their warning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", cf.ConformareExperimentalWarning)
            cf.trackPandasSeries()
            cf.trackPandas__setitem__()
        # Row counts everywhere, plus distribution histograms at the milestones (op-keyed
        # profilers add to the "*" wildcard). Unknown columns at a given step are skipped,
        # so listing a column that only exists later is harmless.
        cf.set_profiles(
            {
                "*": [cf.rowCount],
                "toPandas": [cf.histogram(columns=["age", "spend", "tenure"])],
                "merge": [cf.histogram(columns=["spend", "region_avg_spend", "band_size"])],
                "concat": [cf.histogram(columns=["spend", "spend_per_tenure", "region_rank"])],
                "join": [cf.histogram(columns="all")],
            }
        )
        cf.describe_process(
            "Cross-engine customer analytics: filter in Spark, engineer features in pandas "
            "(including a per-region loop), then return to Spark and join reference data -- "
            "tracked end to end across both conversion boundaries."
        )

        # 1) PySpark: filter on a date and on age -------------------------------------
        with cf.describe(
            "Filter recent adults (Spark)", purpose="Scope to recent, adult signups"
        ):
            customers = spark.createDataFrame(_customers_pdf())
            recent = customers.filter(F.col("signup_date") >= dt.date(2022, 7, 1))
            adults = recent.filter(F.col("age") >= 18)

        # 2 + 3) Pull to pandas and derive columns in place ---------------------------
        with cf.describe("Engineer features (pandas)", purpose="Per-customer derived columns"):
            pdf = adults.toPandas()  # Spark -> pandas
            pdf["age_band"] = (pdf["age"] // 10) * 10  # consecutive in-place writes ...
            pdf["spend_per_tenure"] = pdf["spend"] / pdf["tenure"]  # ... form one chain
            pdf["high_value"] = pdf["spend"] >= 150

        # 4) Two separate group-bys, each merged back ---------------------------------
        with cf.describe(
            "Aggregate and merge back", purpose="Region- and band-level features"
        ):
            region_stats = pdf.groupby("region", as_index=False).agg(
                region_avg_spend=("spend", "mean")
            )
            pdf = pdf.merge(region_stats, on="region", how="left")
            band_stats = pdf.groupby("age_band", as_index=False).agg(
                band_size=("customer_id", "count")
            )
            pdf = pdf.merge(band_stats, on="age_band", how="left")

        # 5) Loop over a categorical: subset -> in-place writes -> concat back together
        #    Not ideal (a groupby/transform would be cleaner), but common. Each iteration
        #    is its own filter + writes branch; pd.concat fans them back into one frame.
        #    NOTE(review): without copy-on-write, pandas may emit SettingWithCopyWarning
        #    for the writes on `sub`; an explicit .copy() would silence it but might add a
        #    tracked op to the diagram -- confirm before changing.
        with cf.describe(
            "Per-region loop (subset, write, concat)", purpose="Within-region ranking"
        ):
            parts = []
            for region in sorted(pdf["region"].unique()):
                sub = pdf[pdf["region"] == region]  # subset matching the categorical
                sub["region_rank"] = sub["spend"].rank(ascending=False)  # write on the subset
                sub["above_region_avg"] = sub["spend"] > sub["region_avg_spend"]
                parts.append(sub)
            combined = pd.concat(parts, ignore_index=True)  # stitch the pieces back together

        # 6 + 7) Back to Spark and join reference data --------------------------------
        with cf.describe(
            "Persist and enrich (Spark)", purpose="Return to Spark, attach managers"
        ):
            sdf = spark.createDataFrame(combined)  # pandas -> Spark
            regions = spark.createDataFrame(_regions_pdf())
            final = sdf.join(regions, on="region", how="left")

        html = cf.to_html(out, title="Cross-engine pipeline (Spark -> pandas loop -> Spark)")

        # Connectivity: every node should be reachable from the roots (one pipeline).
        model = cf.build_model(cf.store)
        seen = _reachable_nodes(model)

        ops = [e.op for e in cf.lineage() if e.kind == "op"]
        profiled = sum(1 for n in model["nodes"] if n.get("histograms"))
        rows = final.count()
    finally:
        # Always call cf.restore() and stop the Spark session, even when a step above
        # raised; the nested finally stops Spark even if restore itself fails.
        try:
            cf.restore()
        finally:
            spark.stop()

    print(f"wrote {out} ({len(html):,} bytes)")
    print(f"  ops        : {ops}")
    print(f"  setitem    : {ops.count('setitem')}")
    print(f"  concat     : {ops.count('concat')}")
    print(f"  histograms : {profiled} nodes carry a distribution profile")
    print(f"  connected  : {len(seen)}/{model['stats']['nodes']} nodes reachable from roots")
    print(f"  final rows : {rows}")
    return out


# Script entry point: run the example only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()

Output report

Open in a new tab ↗


This site uses Just the Docs, a documentation theme for Jekyll.