SQL Rules for Deployment

Pilz can export a trained model as SQL that runs directly in your database.

From Model to SQL

A trained model is exported as a single CASE WHEN expression. Each spore (leaf) becomes one WHEN clause that returns that spore's score:

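# src/pilz/model/pilz.py:18-32
from pydantic import BaseModel

# Spore (a leaf) is defined elsewhere in pilz.model; only its cut and score are used here
class Pilz(BaseModel):
    spore: list[Spore]
    target: str | int

    def get_sql(self, res_name: str) -> str:
        # Wrap all WHEN clauses in one CASE expression named res_name
        return "CASE \n" + "\n".join(self.get_where_sql()) + f"\nEND AS {res_name}"

    def get_where_sql(self) -> list[str]:
        # One WHEN clause per spore: its conditions ANDed together, returning its score
        return [
            f"  WHEN {' AND '.join(spore.cut)} THEN CAST({spore.score} AS DOUBLE)"
            for spore in self.spore
        ]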

Each spore's cut is a list of SQL condition strings (for example "tenure" <= 12 and "Contract" = 'Month-to-month'), which get_where_sql() joins with AND. These are the same filters generated during training by get_left_right_filter(); see Three-Way Splits.
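To see the string assembly in isolation, here is a minimal standalone sketch that mirrors get_where_sql() and get_sql(), putting END on its own line. It assumes a hypothetical dataclass in place of the real pydantic Spore, with only the two fields those methods read (cut as a list of condition strings, score as a number); it is an illustration, not the library's code.

```python
from dataclasses import dataclass


@dataclass
class Spore:
    # Hypothetical stand-in for pilz's Spore: only the fields get_sql() reads
    cut: list[str]  # SQL conditions, ANDed together
    score: float    # value returned when every condition holds


def get_where_sql(spores: list[Spore]) -> list[str]:
    # One WHEN clause per spore, as in Pilz.get_where_sql()
    return [
        f"  WHEN {' AND '.join(s.cut)} THEN CAST({s.score} AS DOUBLE)"
        for s in spores
    ]


def get_sql(spores: list[Spore], res_name: str) -> str:
    # Wrap the WHEN clauses in CASE ... END and name the result column
    return "CASE \n" + "\n".join(get_where_sql(spores)) + f"\nEND AS {res_name}"


spores = [
    Spore(cut=['"tenure" <= 12', "\"Contract\" = 'Month-to-month'"], score=0.68),
    Spore(cut=["\"Contract\" = 'Two year'"], score=0.12),
]
print(get_sql(spores, "churn_score"))
```

Each spore contributes exactly one line, so the size of the generated SQL grows linearly with the number of leaves.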

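Generated SQL

The example below is simplified for readability. The get_sql() code above wraps each score in CAST(... AS DOUBLE), ends with END AS <res_name>, and emits no ELSE branch, so a row that matches no spore gets NULL rather than a default score like the 0.35 shown here.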

CASE
    WHEN "tenure" <= 12 AND "Contract" = 'Month-to-month' THEN 0.68
    WHEN "tenure" > 12 AND "Contract" = 'Month-to-month' AND "InternetService" = 'Fiber optic' THEN 0.55
    WHEN "Contract" = 'Two year' THEN 0.12
    WHEN "Contract" = 'One year' AND "tenure" > 24 THEN 0.18
    ELSE 0.35
END
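A row gets the score of the first WHEN clause it matches; a row that matches none gets the ELSE value, or NULL when there is no ELSE. The sketch below checks this on a toy table, using Python's built-in sqlite3 module as a stand-in database (an assumption for illustration; the customer rows are made up). COALESCE is one way to supply a default when the expression has no ELSE.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE customers (customer_id INTEGER, tenure INTEGER, "Contract" TEXT)')
conn.executemany(
    "INSERT INTO customers VALUES (?, ?, ?)",
    [
        (1, 6, "Month-to-month"),  # matches the first WHEN
        (2, 30, "Two year"),       # matches the second WHEN
        (3, 40, "One year"),       # matches no WHEN
    ],
)

case_sql = """CASE
    WHEN "tenure" <= 12 AND "Contract" = 'Month-to-month' THEN 0.68
    WHEN "Contract" = 'Two year' THEN 0.12
END"""

# Without ELSE, unmatched rows come back as NULL (None in Python)
rows = conn.execute(
    f"SELECT customer_id, {case_sql} FROM customers ORDER BY customer_id"
).fetchall()
print(rows)  # [(1, 0.68), (2, 0.12), (3, None)]

# COALESCE supplies a default score for unmatched rows
rows_default = conn.execute(
    f"SELECT customer_id, COALESCE({case_sql}, 0.35) FROM customers ORDER BY customer_id"
).fetchall()
print(rows_default)  # [(1, 0.68), (2, 0.12), (3, 0.35)]
```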

Running in DuckDB

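import duckdb

conn = duckdb.connect("customers.db")
# DuckDB reads the CSV by path; OR REPLACE keeps the script re-runnable on the same database file
conn.execute("CREATE OR REPLACE TABLE customers AS SELECT * FROM 'customer_data.csv'")

# Paste the generated CASE expression into the SELECT list (shortened here)
result = conn.execute("""
    SELECT
        customer_id,
        CASE
            WHEN tenure <= 12 AND Contract = 'Month-to-month' THEN 0.68
            ELSE 0.35
        END AS predicted_churn
    FROM customers
""").fetchdf()  # returns a pandas DataFrame (pandas must be installed)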

Multiple Trees (Ensemble)

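When you have multiple trees, their scores are averaged. Pilze.get_sql() returns a dict with one CASE expression per entry in pilze, each named res_<name>:

# src/pilz/model/pilz.py:35-43
from pydantic import BaseModel

class Pilze(BaseModel):
    pilze: dict[str, Pilz]  # Pilz as defined above, keyed by name
    target: str

    def get_sql(self) -> dict[str, str]:
        # {"res_<name>": "CASE ... END AS res_<name>", ...}
        return {
            f"res_{name}": pilz.get_sql(f"res_{name}")
            for name, pilz in self.pilze.items()
        }

The scoring query then averages the per-tree expressions, shown here schematically for three trees:

SELECT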
    customer_id,
    (
        /* Tree 0 */
        CASE WHEN ... THEN ... ELSE 0 END +
        /* Tree 1 */
        CASE WHEN ... THEN ... ELSE 0 END +
        /* Tree 2 */
        CASE WHEN ... THEN ... ELSE 0 END
    ) / 3.0 AS churn_score
FROM customers
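The averaging query can be generated rather than hand-written. A minimal sketch, assuming you have each tree's WHEN clauses as plain lists of strings (for example from get_where_sql()); average_case_sql is a hypothetical helper, and the trees and rows are invented. Like the query above, each tree's CASE gets ELSE 0 and the sum is divided by the number of trees.

```python
import sqlite3


def average_case_sql(trees: list[list[str]], alias: str) -> str:
    # Hypothetical helper: one CASE per tree with ELSE 0, summed, then divided
    # by the number of trees (written as a float literal, e.g. 3.0)
    parts = [
        f"/* Tree {i} */ CASE\n" + "\n".join(whens) + "\n  ELSE 0 END"
        for i, whens in enumerate(trees)
    ]
    return "(\n" + " +\n".join(parts) + f"\n) / {float(len(trees))} AS {alias}"


# Made-up WHEN clauses for three tiny trees over a single column
trees = [
    ['  WHEN "tenure" <= 12 THEN 0.6', '  WHEN "tenure" > 12 THEN 0.2'],
    ['  WHEN "tenure" <= 24 THEN 0.3', '  WHEN "tenure" > 24 THEN 0.1'],
    ['  WHEN "tenure" <= 6 THEN 0.9', '  WHEN "tenure" > 6 THEN 0.4'],
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id INTEGER, tenure INTEGER)")
conn.executemany("INSERT INTO customers VALUES (?, ?)", [(1, 3), (2, 36)])

query = f"SELECT customer_id, {average_case_sql(trees, 'churn_score')} FROM customers ORDER BY customer_id"
scores = conn.execute(query).fetchall()
print(scores)  # customer 1: (0.6 + 0.3 + 0.9) / 3, customer 2: (0.2 + 0.1 + 0.4) / 3
```

ELSE 0 counts a tree whose leaves do not cover a row as voting 0, which pulls the average down; with complete trees every row reaches exactly one leaf and the default never fires.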

Handling Large Rule Sets

The Problem

Some models have hundreds or thousands of spores, and a single CASE expression with that many WHEN clauses can exceed a database's limits on query size. The max_parallel_where setting controls when batching kicks in:

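# src/pilz/model/pilz.py:129-141 (a method of Pilz)
from itertools import batched  # Python 3.12+

def get_split_sql(self, max_parallel_where):
    where_cases = self.get_where_sql()
    # One CASE per batch of at most max_parallel_where WHEN clauses;
    # ELSE 0 lets the partial results be summed back together
    return [
        "CASE \n" + "\n".join(batch) + f"\nELSE CAST(0 AS DOUBLE)\nEND AS res_{i}"
        for i, batch in enumerate(batched(where_cases, max_parallel_where))
    ]

# eval_settings.yaml
max_parallel_where: 1000  # Trees with 1000 or more spores are split into batches of 1000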

Batching Strategy

Figure: the original CASE statement with 500 rules is split into three batches (170, 170 and 160 rules, i.e. max_parallel_where = 170), and the three partial results are summed into the combined result.
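Batching is plain fixed-size chunking: get_split_sql() passes the WHEN clauses to batched(), which yields groups of max_parallel_where items with a shorter final group. itertools.batched only exists in Python 3.12 and later, so this sketch includes a small equivalent (chunked, a name chosen here) and reproduces the 500-rule split from the figure. split_case_sql is a standalone stand-in for the method, not Pilz's code.

```python
from itertools import islice
from typing import Iterable, Iterator


def chunked(items: Iterable[str], n: int) -> Iterator[tuple[str, ...]]:
    # Same behaviour as itertools.batched (Python 3.12+): tuples of up to n items
    it = iter(items)
    while batch := tuple(islice(it, n)):
        yield batch


def split_case_sql(where_cases: list[str], max_parallel_where: int) -> list[str]:
    # One CASE per batch; ELSE 0 lets the partial results be summed afterwards
    return [
        "CASE \n" + "\n".join(batch) + f"\nELSE CAST(0 AS DOUBLE)\nEND AS res_{i}"
        for i, batch in enumerate(chunked(where_cases, max_parallel_where))
    ]


where_cases = [f'  WHEN "x" = {k} THEN CAST(1.0 AS DOUBLE)' for k in range(500)]
parts = split_case_sql(where_cases, 170)
print(len(parts), [p.count("WHEN") for p in parts])  # 3 [170, 170, 160]
```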

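During evaluation, a tree with max_parallel_where or more spores is scored with one SQL query per batch, and the batch columns are summed row by row. Because each batch ends in ELSE 0 and a row falls into a single leaf, the sum equals the score the unbatched CASE would return: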

# src/pilz/service/darkwing.py:125-164 (abridged)
import polars as pl

def get_eval_sr(self, pilze, max_parallel_where):
    for name, pilz in pilze.pilze.items():
        if len(pilz.spore) < max_parallel_where:
            # Small tree: one CASE expression, one query
            case_sql = pilz.get_sql(name)
            df = self._get_pl_eval_df(col_sql=case_sql)
        else:
            # Large tree: one query per batch, then sum the res_0, res_1, ... columns per row
            where_cases = pilz.get_split_sql(max_parallel_where)
            df = pl.concat([
                self._get_pl_eval_df(col_sql=sub_sql)
                for sub_sql in where_cases
            ], how="horizontal").sum_horizontal().alias(name).to_frame()
        # ... (rest of the method omitted)

Practical Example

Load the Model

from pathlib import Path

from pilz.model import Pilz

# Load a single tree
pilz = Pilz.model_validate_json(Path("model/Yes/0.json").read_text())

# Get SQL
sql = pilz.get_sql("churn_score")
print(sql)

# Or split into batches of at most 100 WHEN clauses each
batch_sql = pilz.get_split_sql(max_parallel_where=100)
for i, part in enumerate(batch_sql):
    print(f"Batch {i}: {part}")

From Pilze (ensemble)

from pilz.model import Pilze

# Load all trees for a target
pilze = Pilze.load_folder("model/Yes")

# One CASE expression per entry: {"res_<name>": "CASE ... END AS res_<name>"}
sql_by_tree = pilze.get_sql()

Deployment Patterns

Pattern 1: View Creation

CREATE VIEW customer_churn_scores AS
SELECT
    customer_id,
    CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        ...
    END AS churn_probability
FROM customers;

Pattern 2: Materialized Table

CREATE TABLE churn_predictions AS
SELECT
    customer_id,
    CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        ...
    END AS churn_probability,
    CURRENT_TIMESTAMP AS predicted_at
FROM customers;
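Patterns 1 and 2 differ in when scoring happens: a view re-runs the CASE expression on every read, while CREATE TABLE ... AS stores the scores computed at creation time. The sketch below shows that difference with Python's sqlite3 module standing in for your database; the single customer row is invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id INTEGER, contract TEXT, tenure INTEGER)")
conn.execute("INSERT INTO customers VALUES (1, 'Month-to-month', 6)")

case_sql = """CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        ELSE 0.35
    END"""

# Pattern 1: the view stores the query, not the results
conn.execute(
    f"CREATE VIEW customer_churn_scores AS "
    f"SELECT customer_id, {case_sql} AS churn_probability FROM customers"
)
# Pattern 2: the table stores the results computed right now
conn.execute(
    f"CREATE TABLE churn_predictions AS "
    f"SELECT customer_id, {case_sql} AS churn_probability FROM customers"
)

# Time passes and tenure grows; only the view picks up the change
conn.execute("UPDATE customers SET tenure = 18 WHERE customer_id = 1")

view_score = conn.execute("SELECT churn_probability FROM customer_churn_scores").fetchone()[0]
table_score = conn.execute("SELECT churn_probability FROM churn_predictions").fetchone()[0]
print(view_score, table_score)  # 0.35 0.68
```

A materialized table is cheaper to read but must be rebuilt to reflect new data, which is why Pattern 2 records predicted_at.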

Pattern 3: SQL Function (PostgreSQL)

CREATE OR REPLACE FUNCTION predict_churn(p_customer_id INT)
RETURNS FLOAT
LANGUAGE SQL
AS $$
    SELECT CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        WHEN contract = 'Two year' THEN 0.12
        ELSE 0.35
    END
    FROM customers
    WHERE customer_id = p_customer_id
$$;

Security Considerations

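Pilz emits raw SQL strings. When you combine them with user-supplied values in production, pass those values as bound parameters rather than formatting them into the query string: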

# Instead of string formatting
sql = f"SELECT ... WHERE balance > {user_input}"  # Dangerous: the input becomes part of the SQL

# Use a parameterized query and bind the value separately
sql = "SELECT ... WHERE balance > ?"  # e.g. conn.execute(sql, [user_input])

Input Validation

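def safe_prediction(customer_data):
    # Validate inputs before using them in SQL; explicit checks still run
    # under `python -O`, which strips assert statements
    if not 0 <= customer_data["tenure"] <= 100:
        raise ValueError("tenure out of range")
    if customer_data["contract"] not in {"Month-to-month", "One year", "Two year"}:
        raise ValueError("unknown contract type")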

Summary

| Feature | Method | File |
|---|---|---|
| Single tree SQL | Pilz.get_sql() | pilz.py:126 |
| Batched SQL | Pilz.get_split_sql() | pilz.py:129 |
| Ensemble SQL | Pilze.get_sql() | pilz.py:155 |
| Evaluation with batching | Darkwing.get_eval_sr() | darkwing.py:96 |
| Batching threshold | max_parallel_where setting | settings.py |

Next Steps