
Make a Classification Model with BigQuery ML

Prepare Dataset


Create the Model with a SQL Statement

If you are not familiar with the Iris dataset, please read this post first.

We want to build a multiclass classifier that predicts the flower species (Setosa, Versicolor, or Virginica) from four measured features: sepal_length, sepal_width, petal_length, and petal_width.

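Because the label column species has three distinct values, model_type='logistic_reg' trains a multiclass logistic regression model. Setting data_split_method = 'random' splits the input rows randomly into a training set and an evaluation set, and data_split_eval_fraction = 0.2 reserves 20% of the rows for evaluation, giving an 80:20 train/eval split.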

CREATE OR REPLACE MODEL iris.multiclass_model
OPTIONS(model_type='logistic_reg',
         input_label_cols=['species'],
         data_split_method = 'random',
         data_split_eval_fraction = 0.2 )
AS
SELECT
  sepal_length,
  sepal_width,
  petal_length,
  petal_width,
  species
FROM
  `instruction-415216.iris.iris_table`;
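To see what an 80:20 random split amounts to, here is a small local Python sketch. It is only an analogue of data_split_method = 'random' with data_split_eval_fraction = 0.2, not BigQuery ML's implementation. The function name random_split, the fixed seed, and the way the evaluation size is rounded are assumptions made for illustration.

```python
import random


def random_split(rows, eval_fraction=0.2, seed=42):
    """Shuffle a copy of rows and hold out eval_fraction of them for evaluation.

    Hypothetical local analogue of data_split_method='random' with
    data_split_eval_fraction=eval_fraction; BigQuery ML's own split may
    differ in how it samples or rounds.
    """
    if not 0 < eval_fraction < 1:
        raise ValueError("eval_fraction must be between 0 and 1")
    shuffled = list(rows)                  # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)  # fixed seed -> repeatable split
    n_eval = round(len(shuffled) * eval_fraction)
    return shuffled[n_eval:], shuffled[:n_eval]  # (train, evaluation)


# Hypothetical stand-in for the 150 Iris rows: row ids only.
row_ids = list(range(150))
train_ids, eval_ids = random_split(row_ids)
print(len(train_ids), len(eval_ids))  # 120 30
```

Each row lands in exactly one of the two sets. Because the seed is fixed, repeated calls produce the same split.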

Check Model Performance

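Evaluate the model on a test set

The query below keeps about 20% of the rows. It hashes each row with FARM_FINGERPRINT and keeps the rows whose absolute hash value, modulo 10, is 0 or 1. This subset is the same on every run, but it is not the 20% that data_split_method = 'random' held out during training. Some of these rows were therefore probably used to train the model, and the metrics may look better than they would on truly unseen data. To score the model on the evaluation split it reserved during training, run ML.EVALUATE(MODEL iris.multiclass_model) without an input query.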

SELECT *
FROM ML.EVALUATE(MODEL iris.multiclass_model,
  (
    SELECT
      sepal_length,
      sepal_width,
      petal_length,
      petal_width,
      species
    FROM
      `instruction-415216.iris.iris_table`
    # Keep ~20% of rows: those whose row hash, modulo 10, is 0 or 1
    WHERE
      MOD(ABS(FARM_FINGERPRINT(TO_JSON_STRING(struct(sepal_length, sepal_width, petal_length, petal_width, species)))), 10) < 2  
  )
);
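The WHERE clause above is a deterministic hash split: a given row always lands in the same bucket, so the query returns the same test rows on every run. The Python sketch below mimics that rule locally under some assumptions. SHA-256 from hashlib stands in for FARM_FINGERPRINT, and json.dumps stands in for TO_JSON_STRING, so the bucket numbers will not match BigQuery's. The rows and the helper names hash_bucket and is_test_row are hypothetical. The last lines count how many hash-selected rows would also sit in an independent random 80% training split, which shows that the two 20% subsets are not the same rows.

```python
import hashlib
import json
import random


def hash_bucket(row, buckets=10):
    """Map a row to a stable bucket in [0, buckets) by hashing its JSON form.

    Stand-in for MOD(ABS(FARM_FINGERPRINT(TO_JSON_STRING(row))), 10):
    SHA-256 replaces FarmHash and json.dumps replaces TO_JSON_STRING,
    so the buckets will NOT match the ones BigQuery computes.
    """
    payload = json.dumps(row).encode("utf-8")
    digest = hashlib.sha256(payload).digest()
    # Read the first 8 bytes as an unsigned integer, then reduce modulo buckets.
    return int.from_bytes(digest[:8], "big") % buckets


def is_test_row(row):
    """Same rule as the SQL filter: buckets 0 and 1 of 10, about 20% of rows."""
    return hash_bucket(row) < 2


# Hypothetical rows; only their contents matter to the hash.
rows = [{"row_id": i, "petal_length": round(1.0 + 0.04 * i, 2)} for i in range(150)]
test_rows = [r for r in rows if is_test_row(r)]
print(f"{len(test_rows)} of {len(rows)} rows selected for testing")

# An independent random 80:20 split, standing in for data_split_method='random'.
eval_ids = set(random.Random(0).sample(range(len(rows)), k=30))
seen_in_training = sum(1 for r in test_rows if r["row_id"] not in eval_ids)
print(f"{seen_in_training} of the selected test rows are in the random training split")
```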

You will see a result like this:

precision   0.79365079365079372
recall      0.79166666666666663
accuracy    0.83333333333333337
f1_score    0.79084967320261434
log_loss    0.29255131776325521
roc_auc     0.975933732933733
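The Python sketch below shows how the multiclass numbers relate to each other. It computes accuracy plus macro-averaged precision, recall, and F1 from lists of actual and predicted labels. Treat macro averaging as an assumption about how BigQuery ML reports these metrics. It is consistent with the table above: f1_score (0.7908) is lower than 2 × precision × recall / (precision + recall) ≈ 0.7927, which is what happens when F1 is averaged per class instead of being computed from the averaged precision and recall. The function name multiclass_metrics and the toy labels are hypothetical.

```python
from collections import Counter


def multiclass_metrics(actual, predicted):
    """Accuracy plus macro-averaged precision, recall and F1 for label lists."""
    if not actual or len(actual) != len(predicted):
        raise ValueError("need two non-empty label lists of equal length")
    labels = sorted(set(actual) | set(predicted))
    true_pos = Counter(a for a, p in zip(actual, predicted) if a == p)
    predicted_counts = Counter(predicted)  # how often each label was predicted
    actual_counts = Counter(actual)        # how often each label really occurs

    precisions, recalls, f1s = [], [], []
    for label in labels:
        tp = true_pos[label]
        p = tp / predicted_counts[label] if predicted_counts[label] else 0.0
        r = tp / actual_counts[label] if actual_counts[label] else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(2 * p * r / (p + r) if p + r else 0.0)

    k = len(labels)
    return {
        "precision": sum(precisions) / k,  # macro average: each class weighs the same
        "recall": sum(recalls) / k,
        "accuracy": sum(true_pos.values()) / len(actual),  # overall fraction correct
        "f1_score": sum(f1s) / k,  # mean of per-class F1, not F1 of the means
    }


# Hypothetical toy labels: one versicolor flower predicted as virginica.
actual = ["setosa", "setosa", "versicolor", "versicolor", "virginica", "virginica"]
predicted = ["setosa", "setosa", "versicolor", "virginica", "virginica", "virginica"]
for name, value in multiclass_metrics(actual, predicted).items():
    print(f"{name}: {value:.4f}")
```

With these toy labels the sketch prints precision 0.8889, recall 0.8333, accuracy 0.8333, and f1_score 0.8222. Plugging the averaged precision and recall into 2PR/(P+R) would give about 0.8602, again higher than the mean of the per-class F1 scores.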

Make predictions and check individual results

SELECT
  sepal_length,
  sepal_width,
  petal_length,
  petal_width,
  species AS actual_species,
  predicted_species
FROM
  ML.PREDICT(MODEL iris.multiclass_model,
    (
      SELECT
        sepal_length,
        sepal_width,
        petal_length,
        petal_width,
        species
      FROM
        `instruction-415216.iris.iris_table`
      # Keep ~20% of rows: those whose row hash, modulo 10, is 0 or 1
      WHERE
        MOD(ABS(FARM_FINGERPRINT(TO_JSON_STRING(struct(sepal_length, sepal_width, petal_length, petal_width, species)))), 10) < 2 
    )
  );

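The result lists each test row's features with its actual species next to the predicted species, so you can spot the flowers the model misclassifies. ML.PREDICT also returns a predicted_species_probs column with the probability of each class, which you can add to the SELECT list.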
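To summarize the per-row comparison, you can count (actual, predicted) pairs into a confusion matrix and pull out the misclassified rows. The Python sketch below does this on rows shaped like the output of the prediction query. The rows and the helper names confusion_matrix, misclassified, and format_matrix are hypothetical. Inside BigQuery, the ML.CONFUSION_MATRIX function builds a confusion matrix for a model directly.

```python
from collections import Counter


def confusion_matrix(rows):
    """Count (actual_species, predicted_species) pairs across prediction rows."""
    return Counter((r["actual_species"], r["predicted_species"]) for r in rows)


def misclassified(rows):
    """Rows where the predicted species differs from the actual species."""
    return [r for r in rows if r["actual_species"] != r["predicted_species"]]


def format_matrix(matrix, labels):
    """Render counts as a table: one line per actual label, one column per prediction."""
    width = max(len(label) for label in labels) + 2
    header = "actual / predicted".ljust(20) + "".join(label.ljust(width) for label in labels)
    lines = [header]
    for a in labels:
        # Counter returns 0 for pairs that never occurred.
        lines.append(a.ljust(20) + "".join(str(matrix[(a, p)]).ljust(width) for p in labels))
    return "\n".join(lines)


# Hypothetical rows shaped like the ML.PREDICT query output (feature columns trimmed).
rows = [
    {"petal_length": 1.4, "actual_species": "setosa", "predicted_species": "setosa"},
    {"petal_length": 4.7, "actual_species": "versicolor", "predicted_species": "versicolor"},
    {"petal_length": 5.0, "actual_species": "versicolor", "predicted_species": "virginica"},
    {"petal_length": 5.8, "actual_species": "virginica", "predicted_species": "virginica"},
]
labels = ["setosa", "versicolor", "virginica"]
print(format_matrix(confusion_matrix(rows), labels))
for r in misclassified(rows):
    print("misclassified:", r)
```

The diagonal of the matrix holds the correct predictions. Any off-diagonal count, such as versicolor predicted as virginica here, points to the rows worth inspecting.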