SQL Rules for Deployment
Pilz can export a trained model as plain SQL that runs directly in your database.
From Model to SQL
Trained models generate CASE WHEN expressions. Each spore (leaf) becomes one WHEN clause with its score:
# src/pilz/model/pilz.py:18-32
class Pilz(BaseModel):
    spore: list[Spore]
    target: str | int

    def get_sql(self, res_name: str) -> str:
        return "CASE \n" + "\n".join(self.get_where_sql()) + f"END AS {res_name}"

    def get_where_sql(self) -> list[str]:
        return [
            f" WHEN {' AND '.join(spore.cut)} THEN CAST({spore.score} AS DOUBLE)"
            for spore in self.spore
        ]
Each spore's cut is a list of SQL condition strings (e.g., "tenure" <= 12 and "Contract" = 'Month-to-month'), which get_where_sql() joins with AND. These are the same filters generated during training via get_left_right_filter() (see Three-Way Splits).
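As a minimal sketch of how the pieces fit together, the snippet below rebuilds the same string logic with plain dataclasses instead of the real pydantic models. `SporeSketch` and `build_case_sql` are hypothetical stand-ins, not part of Pilz, and the cuts and scores are made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class SporeSketch:
    """Hypothetical stand-in for Spore: a list of SQL conditions plus a score."""
    cut: list[str]
    score: float


def build_case_sql(spores: list[SporeSketch], res_name: str) -> str:
    """Same idea as the documented methods: one WHEN clause per spore.

    Unlike the excerpt above, a newline is placed before END for readability.
    """
    when_clauses = [
        f"    WHEN {' AND '.join(s.cut)} THEN CAST({s.score} AS DOUBLE)"
        for s in spores
    ]
    return "CASE\n" + "\n".join(when_clauses) + f"\nEND AS {res_name}"


# Invented spores for illustration only
spores = [
    SporeSketch(cut=['"tenure" <= 12', '"Contract" = \'Month-to-month\''], score=0.68),
    SporeSketch(cut=['"Contract" = \'Two year\''], score=0.12),
]
case_sql = build_case_sql(spores, "churn_score")
print(case_sql)
```

Printing `case_sql` shows one WHEN line per spore, with the conditions of each cut joined by AND.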
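Generated SQL
The output has the following shape. This example is simplified: the get_sql() method shown above wraps every score in CAST(... AS DOUBLE), appends END AS plus the result name, and emits no ELSE branch, so rows that match no spore evaluate to NULL.
CASE
    WHEN "tenure" <= 12 AND "Contract" = 'Month-to-month' THEN 0.68
    WHEN "tenure" > 12 AND "Contract" = 'Month-to-month' AND "InternetService" = 'Fiber optic' THEN 0.55
    WHEN "Contract" = 'Two year' THEN 0.12
    WHEN "Contract" = 'One year' AND "tenure" > 24 THEN 0.18
    ELSE 0.35
END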
Running in DuckDB
import duckdb

conn = duckdb.connect("customers.db")
conn.execute("CREATE TABLE customers AS SELECT * FROM 'customer_data.csv'")

result = conn.execute("""
    SELECT
        customer_id,
        CASE
            WHEN tenure <= 12 AND Contract = 'Month-to-month' THEN 0.68
            ELSE 0.35
        END AS predicted_churn
    FROM customers
""").fetchdf()
Multiple Trees (Ensemble)
When you have multiple trees, scores are averaged. The Pilze wrapper generates SQL for all target classes:
# src/pilz/model/pilz.py:35-43
class Pilze(BaseModel):
    pilze: dict[str, Pilz]
    target: str

    def get_sql(self) -> dict[str, str]:
        return {
            f"res_{name}": pilz.get_sql(f"res_{name}")
            for name, pilz in self.pilze.items()
        }
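To average several trees in one query, add their CASE expressions and divide by the number of trees. Each expression needs an ELSE 0 branch here, because a NULL from any single tree would make the whole sum NULL:
SELECT
    customer_id,
    (
        /* Tree 0 */
        CASE WHEN ... THEN ... ELSE 0 END +
        /* Tree 1 */
        CASE WHEN ... THEN ... ELSE 0 END +
        /* Tree 2 */
        CASE WHEN ... THEN ... ELSE 0 END
    ) / 3.0 AS churn_score
FROM customers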
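A query of this shape can be assembled programmatically. The following sketch assumes you already hold one CASE expression per tree as a string without a trailing alias; `average_case_sql` is a hypothetical helper, not part of Pilz, and the trees are invented.

```python
def average_case_sql(case_exprs: list[str], alias: str) -> str:
    """Sum the per-tree CASE expressions and divide by the tree count."""
    if not case_exprs:
        raise ValueError("need at least one tree")
    summed = "\n    + ".join(f"({expr})" for expr in case_exprs)
    # Dividing by a float literal avoids integer division in engines that have it
    return f"(\n    {summed}\n) / {float(len(case_exprs))} AS {alias}"


# Made-up single-clause trees; each falls back to 0 so the sum is never NULL
trees = [
    "CASE WHEN tenure <= 12 THEN 0.6 ELSE 0 END",
    "CASE WHEN tenure <= 24 THEN 0.3 ELSE 0 END",
    "CASE WHEN tenure > 24 THEN 0.9 ELSE 0 END",
]
avg_sql = average_case_sql(trees, "churn_score")
print(f"SELECT\n    customer_id,\n{avg_sql}\nFROM customers")
```

Wrapping each expression in parentheses keeps operator precedence unambiguous when the pieces are joined with `+`.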
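Handling Large Rule Sets
The Problem
Some models have hundreds of spores, and a single CASE expression with that many branches can exceed SQL length limits. The max_parallel_where setting caps the number of WHEN clauses per expression; larger models are split into batches of that size: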
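# src/pilz/model/pilz.py:129-141
from itertools import batched  # assumed source of batched(); needs Python 3.12+

def get_split_sql(self, max_parallel_where):
    where_cases = self.get_where_sql()
    return [
        # f-string, so that {i} is replaced by the batch index
        "CASE \n" + "\n".join(batch) + f"\nELSE CAST(0 AS DOUBLE)\n END AS res_{i}"
        for i, batch in enumerate(batched(where_cases, max_parallel_where))
    ]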
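To see how the batching behaves, here is a small self-contained sketch. `split_case_sql` is a hypothetical free-function version of the method above, and `chunked` is a local fallback for `itertools.batched`, which only exists in Python 3.12 and later. The WHEN clauses are invented.

```python
from itertools import islice


def chunked(items, size):
    """Yield consecutive tuples of at most `size` items (like itertools.batched)."""
    if size < 1:
        raise ValueError("size must be at least 1")
    it = iter(items)
    while batch := tuple(islice(it, size)):
        yield batch


def split_case_sql(where_cases: list[str], max_parallel_where: int) -> list[str]:
    """Build one CASE expression per batch, aliased res_0, res_1, ..."""
    return [
        "CASE\n" + "\n".join(batch) + f"\n    ELSE CAST(0 AS DOUBLE)\nEND AS res_{i}"
        for i, batch in enumerate(chunked(where_cases, max_parallel_where))
    ]


# Five made-up WHEN clauses, split into batches of two
clauses = [f"    WHEN x = {n} THEN CAST({n / 10} AS DOUBLE)" for n in range(5)]
batches = split_case_sql(clauses, max_parallel_where=2)
print(len(batches))  # 3 batches holding 2, 2 and 1 clauses
```

The last batch is simply shorter than the others; no padding is needed because every batch carries its own ELSE branch.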
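Batching Strategy
During evaluation, a model with fewer spores than max_parallel_where is scored with a single CASE expression. Larger models are split into batches, each batch is evaluated as its own query, and the resulting columns are concatenated horizontally and summed row by row: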
# src/pilz/service/darkwing.py:125-164
def get_eval_sr(self, pilze, max_parallel_where):
    for name, pilz in pilze.pilze.items():
        if len(pilz.spore) < max_parallel_where:
            case_sql = pilz.get_sql(name)
            df = self._get_pl_eval_df(col_sql=case_sql)
        else:
            where_cases = pilz.get_split_sql(max_parallel_where)
            df = pl.concat([
                self._get_pl_eval_df(col_sql=sub_sql)
                for sub_sql in where_cases
            ], how="horizontal").sum_horizontal().alias(name).to_frame()
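Summing works because each batch returns 0 for rows it does not match, so when a row satisfies a clause in only one batch, adding the batch columns gives back that clause's score. The sketch below checks this on invented data. It uses sqlite3 in place of the real evaluation backend and does the sum in SQL rather than in Polars; both are assumptions made for the sake of a dependency-free example.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER, x INTEGER)")
conn.executemany("INSERT INTO t VALUES (?, ?)", [(i, i) for i in range(4)])

# One mutually exclusive clause per value of x (made-up scores 0.1 to 0.4)
clauses = [f"WHEN x = {n} THEN CAST({(n + 1) / 10} AS DOUBLE)" for n in range(4)]
fallback = " ELSE CAST(0 AS DOUBLE) END"

# Unbatched: a single CASE over all clauses
single = "CASE " + " ".join(clauses) + fallback

# Batched: two CASE expressions of two clauses each, added together
batch_a = "CASE " + " ".join(clauses[:2]) + fallback
batch_b = "CASE " + " ".join(clauses[2:]) + fallback

rows = conn.execute(
    f"SELECT {single} AS full_score, ({batch_a}) + ({batch_b}) AS summed "
    "FROM t ORDER BY id"
).fetchall()
conn.close()

# Every row should carry the same score in both columns
print(all(abs(full - summed) < 1e-12 for full, summed in rows))
```

If two spores in different batches could match the same row, the sum would double-count, so this relies on the clauses being mutually exclusive.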
Practical Example
Load the Model
from pathlib import Path

from pilz.model import Pilz

# Load a single tree
pilz = Pilz.model_validate_json(Path("model/Yes/0.json").read_text())

# Get SQL
sql = pilz.get_sql("churn_score")
print(sql)

# Or split into batches (the parameter is named max_parallel_where)
batch_sql = pilz.get_split_sql(max_parallel_where=100)
for i, batch in enumerate(batch_sql):
    print(f"Batch {i}: {batch}")
From Pilze (ensemble)
from pilz.model import Pilze

# Load all trees for a target
pilze = Pilze.load_folder("model/Yes")

# Get one CASE expression per entry, keyed as "res_{name}"
# (Pilze.get_sql() takes no arguments and returns a dict)
sql_by_name = pilze.get_sql()
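According to the Pilze excerpt shown earlier, `Pilze.get_sql()` returns a dictionary that maps `res_{name}` to a CASE expression already ending in `END AS res_{name}`, so the values can be dropped straight into a SELECT list. A rough sketch, with a hand-written dictionary standing in for the real return value and `build_scoring_query` as a hypothetical helper:

```python
def build_scoring_query(sql_by_name: dict[str, str], table: str, id_col: str) -> str:
    """Join the per-name CASE expressions into one SELECT statement.

    `table` and `id_col` are interpolated as-is, so they must come from
    trusted configuration, never from user input.
    """
    columns = ",\n    ".join([id_col, *sql_by_name.values()])
    return f"SELECT\n    {columns}\nFROM {table}"


# Hand-written stand-in for the dictionary returned by Pilze.get_sql()
sql_by_name = {
    "res_Yes": "CASE WHEN tenure <= 12 THEN CAST(0.68 AS DOUBLE) END AS res_Yes",
    "res_No": "CASE WHEN tenure <= 12 THEN CAST(0.32 AS DOUBLE) END AS res_No",
}
query = build_scoring_query(sql_by_name, table="customers", id_col="customer_id")
print(query)
```

The resulting query yields one score column per dictionary entry alongside the id column.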
Deployment Patterns
Pattern 1: View Creation
CREATE VIEW customer_churn_scores AS
SELECT
    customer_id,
    CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        ...
    END AS churn_probability
FROM customers;
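Here is a sketch of Pattern 1 driven from Python, assuming the CASE expression is already available as a string. It uses sqlite3 and invented rows so that it runs without extra dependencies; on another engine, mainly the connection code should differ.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id INTEGER, contract TEXT, tenure INTEGER)")
conn.execute("INSERT INTO customers VALUES (1, 'Month-to-month', 6)")

# Stand-in for a generated expression (shortened to two branches)
case_sql = """CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        ELSE 0.35
    END AS churn_probability"""

conn.execute(f"""
    CREATE VIEW customer_churn_scores AS
    SELECT customer_id, {case_sql}
    FROM customers
""")

# A view is evaluated at query time, so rows added later are scored too
conn.execute("INSERT INTO customers VALUES (2, 'Two year', 40)")
scores = conn.execute(
    "SELECT customer_id, churn_probability FROM customer_churn_scores ORDER BY customer_id"
).fetchall()
conn.close()

print(scores)
```

Customer 2 was inserted after the view was created and still receives a score, which is the main difference from a materialized table.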
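Pattern 2: Materialized Table
In this pattern, run_prediction() stands in for the generated CASE expression; substitute the expression in its place.
CREATE TABLE churn_predictions AS
SELECT
    customer_id,
    run_prediction() AS churn_probability,
    CURRENT_TIMESTAMP AS predicted_at
FROM customers;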
Pattern 3: Stored Procedure
This example uses PostgreSQL syntax and wraps the rules in a SQL function:
CREATE OR REPLACE FUNCTION predict_churn(p_customer_id INT)
RETURNS FLOAT
LANGUAGE SQL
AS $$
    SELECT CASE
        WHEN contract = 'Month-to-month' AND tenure <= 12 THEN 0.68
        WHEN contract = 'Two year' THEN 0.12
        ELSE 0.35
    END
    FROM customers
    WHERE customer_id = p_customer_id
$$;
Security Considerations
Pilz generates raw SQL strings. When a production query also includes user-supplied values, bind them as parameters instead of concatenating them into the SQL:
# Instead of string concatenation
sql = f"SELECT ... WHERE balance > {user_input}" # Dangerous!
# Use parameterized queries
sql = "SELECT ... WHERE balance > ?" # Safe
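For a concrete version of the safe variant, here is a sketch using sqlite3, whose placeholder style is `?`; other drivers may use a different placeholder syntax. The `accounts` table and the `accounts_above` helper are invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE accounts (id INTEGER, balance REAL)")
conn.executemany("INSERT INTO accounts VALUES (?, ?)", [(1, 50.0), (2, 500.0)])


def accounts_above(conn: sqlite3.Connection, raw_threshold: str) -> list[tuple]:
    """Return ids of accounts whose balance exceeds a user-supplied threshold."""
    # Convert first: float() raises ValueError for input such as "0 OR 1=1"
    threshold = float(raw_threshold)
    # The value is bound as data by the driver, never spliced into the SQL text
    return conn.execute(
        "SELECT id FROM accounts WHERE balance > ? ORDER BY id", (threshold,)
    ).fetchall()


print(accounts_above(conn, "100"))
```

Converting the raw string before binding gives two layers of protection: malformed input is rejected early, and well-formed input never becomes part of the SQL text.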
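Input Validation
def safe_prediction(customer_data):
    # Validate inputs before using them in SQL. Raise explicitly rather than
    # assert, because assert statements are skipped when Python runs with -O.
    if not 0 <= customer_data['tenure'] <= 100:
        raise ValueError("tenure out of range")
    if customer_data['contract'] not in ('Month-to-month', 'One year', 'Two year'):
        raise ValueError("unknown contract type")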
Summary
| Feature | Method | File |
|---|---|---|
| Single tree SQL | Pilz.get_sql() | pilz.py:126 |
| Batched SQL | Pilz.get_split_sql() | pilz.py:129 |
| Ensemble SQL | Pilze.get_sql() | pilz.py:155 |
| Evaluation with batching | Darkwing.get_eval_sr() | darkwing.py:96 |
| Batching threshold | max_parallel_where setting | settings.py |
Next Steps
- Settings Reference — All deployment settings
- Best Practices — Production tips
- Troubleshooting — Common issues