Interpretationsfähigkeit - Tabellarische SHAP-Erklärung

Verwenden Sie Kernel SHAP (SHapley Additive exPlanations), um ein tabellarisches Klassifizierungsmodell zu erläutern. Kernel SHAP ist eine modellagnostische Methode, die den Beitrag jedes Features zur Vorhersage eines Modells schätzt. Sie trainieren ein logistisches Regressionsmodell für das Dataset "Adult Census Income" und verwenden dann den SynapseML-Transformator TabularSHAP , um Erläuterungen auf Featureebene zu berechnen.

Voraussetzungen

  • Erstellen Sie ein neues Notizbuch in Ihrem Arbeitsbereich, und fügen Sie es an ein Seehaus an. Weitere Informationen finden Sie unter Erstellen eines Notizbuchs.

SynapseML, PySpark, Pandas und Plotly werden in Fabric Notizbuchumgebungen vorinstalliert. Es ist keine zusätzliche Paketinstallation erforderlich.

Importieren von Paketen und Definieren von Hilfs-UDFs

Fügen Sie in Ihrem Fabric-Notizbuch den folgenden Code in eine Zelle ein, und führen Sie ihn aus. In diesem Schritt werden die erforderlichen Bibliotheken importiert und zwei benutzerdefinierte Funktionen (UDFs) zum späteren Extrahieren von Vektorelementen definiert.

import pyspark
from synapse.ml.explainers import TabularSHAP
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import FloatType, ArrayType
from pyspark.sql.functions import col, lit, rand, broadcast, udf
import pandas as pd

vec_access = udf(lambda v, i: float(v[i]), FloatType())
vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))

Überprüfen: Führen Sie den folgenden Code in einer neuen Zelle aus. Die Ausgabe TabularSHAP imported successfullysollte angezeigt werden.

print("TabularSHAP imported successfully")
print(f"PySpark version: {pyspark.__version__}")

Laden von Daten und Trainieren eines Klassifizierungsmodells

Laden Sie das Dataset "Adult Census Income" aus Azure Blob Storage, indizieren Sie die Zielbezeichnung, und trainieren Sie eine logistische Regressionspipeline.

df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)

labelIndexer = StringIndexer(
    inputCol="income", outputCol="label", stringOrderType="alphabetAsc"
).fit(df)
print("Label index assignment: " + str(set(zip(labelIndexer.labels, [0, 1]))))

training = labelIndexer.transform(df).cache()

categorical_features = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
categorical_features_idx = [feat + "_idx" for feat in categorical_features]
categorical_features_enc = [feat + "_enc" for feat in categorical_features]
numeric_features = [
    "age",
    "education-num",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
]

strIndexer = StringIndexer(
    inputCols=categorical_features, outputCols=categorical_features_idx
)
onehotEnc = OneHotEncoder(
    inputCols=categorical_features_idx, outputCols=categorical_features_enc
)
vectAssem = VectorAssembler(
    inputCols=categorical_features_enc + numeric_features, outputCol="features"
)
lr = LogisticRegression(featuresCol="features", labelCol="label", weightCol="fnlwgt")
pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])
model = pipeline.fit(training)

Überprüfen: Führen Sie die folgende Zelle aus. Die Zeilenanzahl für Schulungsdaten und die Bestätigung von Pipelinephasen sollte angezeigt werden.

print(f"Training rows: {training.count()}")
print(f"Pipeline stages: {[type(s).__name__ for s in model.stages]}")
assert training.count() > 30000, "Dataset should contain over 30,000 rows"
print("Model trained successfully")

# Expected output:
#Training rows: 32561
#Pipeline stages: ['StringIndexerModel', 'OneHotEncoderModel', #'VectorAssembler', 'LogisticRegressionModel']
#Model trained successfully

Wählen Sie Beobachtungen zur Erklärung aus.

Wählen Sie zufällig fünf Beobachtungen aus den bewerteten Schulungsdaten aus. Diese Beobachtungen sind die Instanzen, für die Sie SHAP-Erklärungen erstellen.

explain_instances = (
    model.transform(training).orderBy(rand()).limit(5).repartition(200).cache()
)
display(explain_instances)

Überprüfen: Überprüfen Sie die Beispielgröße.

count = explain_instances.count()
print(f"Explain instances: {count}")
assert count == 5, f"Expected 5 rows, got {count}"
print("Sample selected successfully")

Konfigurieren und Ausführen von TabularSHAP

Erstellen Sie eine TabularSHAP Erklärung und wenden Sie sie auf die ausgewählten Beobachtungen an. Die wichtigsten Parameter sind:

Parameter Beschreibung
inputCols Merkmalsspalten, die das Modell für die Vorhersage verwendet.
outputCol Name der Spalte, die SHAP-Ausgabewerte enthält.
numSamples Anzahl der Perturbationsbeispiele für die Kernel SHAP-Schätzung. Höhere Werte sind genauer, aber langsamer.
model Das trainierte Pipelinemodell zur Erläuterung.
targetCol Die Ausgabespalte des Modells, die erläutert werden soll. In diesem Beispiel ist die Spalte probability.
targetClasses Zu erläuternde Klassenindizes. [1] erklärt nur die Wahrscheinlichkeit der Klasse 1. Verwenden Sie [0, 1], um beide Klassen zu erläutern.
backgroundData Ein Beispiel für Schulungsdaten, die als Referenzverteilung für die Integration von Features verwendet werden.
shap = TabularSHAP(
    inputCols=categorical_features + numeric_features,
    outputCol="shapValues",
    numSamples=5000,
    model=model,
    targetCol="probability",
    targetClasses=[1],
    backgroundData=broadcast(training.orderBy(rand()).limit(100).cache()),
)

shap_df = shap.transform(explain_instances)

Note

Dieser Schritt kann je nach numSamples Clustergröße mehrere Minuten dauern. Mit numSamples=5000 dauert es bei fünf Beobachtungen auf einem standardmäßigen Fabric Spark-Cluster voraussichtlich 3 bis 10 Minuten.

Prüfen: Prüfen Sie, ob die SHAP-Ausgabespalte vorhanden ist.

assert "shapValues" in shap_df.columns, "shapValues column missing"
print(f"SHAP output columns: {shap_df.columns}")
print("TabularSHAP transform completed")

SHAP-Werte extrahieren

Extrahieren Sie die Wahrscheinlichkeit der Klasse 1 und die SHAP-Werte aus dem Ergebnis-DataFrame. Für jede Beobachtung beginnt der SHAP-Wertevektor mit dem Basiswert (mittlere Ausgabe des Hintergrund-Datasets), gefolgt von einem Wert pro Feature.

shaps = (
    shap_df.withColumn("probability", vec_access(col("probability"), lit(1)))
    .withColumn("shapValues", vec2array(col("shapValues").getItem(0)))
    .select(
        ["shapValues", "probability", "label"] + categorical_features + numeric_features
    )
)

shaps_local = shaps.toPandas()
shaps_local.sort_values("probability", ascending=False, inplace=True, ignore_index=True)
pd.set_option("display.max_colwidth", None)
display(shaps_local)

Überprüfen: Bestätigen Sie die Struktur des pandas-DataFrame.

expected_cols = len(categorical_features) + len(numeric_features) + 3
print(f"DataFrame shape: {shaps_local.shape}")
print(f"Expected columns: {expected_cols}, Actual: {shaps_local.shape[1]}")
assert shaps_local.shape == (5, expected_cols), f"Unexpected shape: {shaps_local.shape}"
print("SHAP values extracted successfully")

Visualisieren von SHAP-Werten

Erstellen Sie ein Balkendiagramm für jede Beobachtung, die zeigt, wie jedes Feature zur vorhergesagten Wahrscheinlichkeit beiträgt.

from plotly.subplots import make_subplots
import plotly.graph_objects as go

features = categorical_features + numeric_features
features_with_base = ["Base"] + features

rows = shaps_local.shape[0]

fig = make_subplots(
    rows=rows,
    cols=1,
    subplot_titles="Probability: "
    + shaps_local["probability"].apply("{:.2%}".format)
    + "; Label: "
    + shaps_local["label"].astype(str),
)

for index, row in shaps_local.iterrows():
    feature_values = [0] + [row[feature] for feature in features]
    shap_values = row["shapValues"]
    list_of_tuples = list(zip(features_with_base, feature_values, shap_values))
    shap_pdf = pd.DataFrame(list_of_tuples, columns=["name", "value", "shap"])
    fig.add_trace(
        go.Bar(
            x=shap_pdf["name"],
            y=shap_pdf["shap"],
            hovertext="value: " + shap_pdf["value"].astype(str),
        ),
        row=index + 1,
        col=1,
    )

fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor="black")
fig.update_xaxes(type="category", tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()

Überprüfen: Bestätigen Sie, dass das Plotobjekt erstellt wurde.

print(f"Figure traces: {len(fig.data)}")
print(f"Figure height: {fig.layout.height}px")
assert len(fig.data) == 5, f"Expected 5 traces, got {len(fig.data)}"
print("Visualization created successfully")

Interpretieren der Ergebnisse

Jeder Teilplot stellt eine Beobachtung dar. Die Balken zeigen:

  • Basiswert: Die durchschnittliche Modellausgabe über den Hintergrunddatensatz hinweg (Basiswahrscheinlichkeit).
  • Positive SHAP-Werte: Features, die die Vorhersage in Richtung Klasse 1 (Einkommen größer als 50K) bewegen.
  • Negative SHAP-Werte: Features, die die Vorhersage in Richtung Klasse 0 verschieben (Einkommen kleiner oder gleich 50K).

Die Summe des Basiswerts und aller Feature-SHAP-Werte entspricht der vorhergesagten Wahrscheinlichkeit des Modells für diese Beobachtung.

Problembehandlung

Thema Ursache Resolution
OutOfMemoryError während TabularSHAP numSamples ist zu groß für den verfügbaren Arbeitsspeicher. Reduzieren Sie numSamples, beispielsweise auf 1.000, oder erhöhen Sie den Spark-Executor-Speicher.
Die SHAP-Transformation ist langsam Ein hoher Wert numSamples mit vielen Funktionen erhöht die Berechnungszeit. Reduzieren Sie numSamples sie auf 1.000-2.000, um schnellere explorative Ergebnisse zu erzielen. Erhöhung für die abschließende Analyse.
FileNotFoundException für Parkett Der Netzwerkzugriff auf mmlspark.blob.core.windows.net ist blockiert. Stellen Sie sicher, dass Ihr Fabric Arbeitsbereich über ausgehenden Internetzugriff verfügt. Alternativ können Sie das Dataset in Ihr Seehaus hochladen.
shapValues Spalte enthält Nullwerte. Einige Beobachtungen können fehlschlagen, wenn Featurewerte außerhalb der Schulungsverteilung liegen. Prüfen Sie die Eingabemerkmale auf Nullwerte oder unerwartete Werte. Nullwerte aus den Ergebnissen filtern.
display() zeigt keine Ausgabe an. Der Code wird außerhalb einer Fabric Notizbuchumgebung ausgeführt. Verwenden Sie shaps_local.head() oder print(shaps_local) in Standard-Python Umgebungen.

Aufräumen

Wenn Sie das Dataset für dieses Lernprogramm in ein Seehaus hochgeladen haben, entfernen Sie es, um Speicherplatz frei zu machen:

# Remove cached DataFrames from memory
training.unpersist()
explain_instances.unpersist()
print("Cached DataFrames released")