
Custom metrics in Keras and how simple they are to use in TensorFlow 2.2

 4 years ago
source link: https://mc.ai/custom-metrics-in-keras-and-how-simple-they-are-to-use-in-tensorflow2-2/

So let's get down to it. We first make a custom metric class. While there are more steps to this, and they are shown in the referenced Jupyter notebook, the important thing is to implement the API that integrates with the rest of the Keras training and testing workflow. That is as simple as implementing an update_state that takes in the true labels and predictions, a result that returns the metric value, and a reset_states that re-initializes the metric.

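import tensorflow as tf


class ConfusionMatrixMetric(tf.keras.metrics.Metric):
    """Accumulates a confusion matrix and reports per-class precision, recall and F1."""

    # The constructor and reset_states are not shown in this excerpt. The two
    # methods below are an assumed reconstruction from the attributes used later.
    def __init__(self, num_classes, name="confusion_matrix_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.total_cm = self.add_weight(
            name="total", shape=(num_classes, num_classes), initializer="zeros")

    def reset_states(self):
        # Zero the accumulated matrix so each epoch starts fresh.
        self.total_cm.assign(tf.zeros_like(self.total_cm))

    def update_state(self, y_true, y_pred, sample_weight=None):
        # sample_weight is accepted for API compatibility but is not used.
        self.total_cm.assign_add(self.confusion_matrix(y_true, y_pred))
        return self.total_cm

    def result(self):
        return self.process_confusion_matrix()

    def confusion_matrix(self, y_true, y_pred):
        """Make a confusion matrix for one batch (rows: true class, columns: predicted class)."""
        y_pred = tf.argmax(y_pred, 1)
        cm = tf.math.confusion_matrix(
            y_true, y_pred, dtype=tf.float32, num_classes=self.num_classes)
        return cm

    def process_confusion_matrix(self):
        """Return per-class precision, recall and F1."""
        cm = self.total_cm
        diag_part = tf.linalg.diag_part(cm)
        # Column sums count predictions per class; row sums count true labels per class.
        precision = diag_part / (tf.reduce_sum(cm, 0) + tf.constant(1e-15))
        recall = diag_part / (tf.reduce_sum(cm, 1) + tf.constant(1e-15))
        f1 = 2 * precision * recall / (precision + recall + tf.constant(1e-15))
        return precision, recall, f1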
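
To make the accumulate, report, reset cycle concrete without needing TensorFlow installed, here is a framework-free sketch. PlainConfusionMatrix is a hypothetical stand-in built on Python lists, not part of Keras. It only mirrors the method names and the row/column convention used above (rows are true classes, columns are predicted classes), and it assumes the predictions are already class indices.

```python
class PlainConfusionMatrix:
    """Hypothetical pure-Python stand-in for the Keras metric above."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset_states()

    def reset_states(self):
        # Start again from an all-zero matrix.
        self.total_cm = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update_state(self, y_true, y_pred):
        # y_pred holds class indices here (the argmax is assumed already taken).
        for t, p in zip(y_true, y_pred):
            self.total_cm[t][p] += 1

    def result(self):
        eps = 1e-15
        n = self.num_classes
        diag = [self.total_cm[i][i] for i in range(n)]
        # Column sums: how often each class was predicted.
        col_sums = [sum(self.total_cm[r][c] for r in range(n)) for c in range(n)]
        # Row sums: how often each class was the true label.
        row_sums = [sum(row) for row in self.total_cm]
        precision = [d / (s + eps) for d, s in zip(diag, col_sums)]
        recall = [d / (s + eps) for d, s in zip(diag, row_sums)]
        f1 = [2 * p * r / (p + r + eps) for p, r in zip(precision, recall)]
        return precision, recall, f1


metric = PlainConfusionMatrix(num_classes=2)
metric.update_state([0, 0, 1], [0, 1, 1])  # first batch
metric.update_state([1, 1, 0], [1, 0, 0])  # second batch
precision, recall, f1 = metric.result()
```

The state grows across both update_state calls, result only reads it, and reset_states is what returns the matrix to zeros between epochs.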

In the normal Keras workflow, the result method is called, returns a single number, and nothing else needs to be done. In our case, however, result returns three tensors (precision, recall and F1), and Keras does not know how to handle this out of the box. This is where the new features of TensorFlow 2.2 come in.
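
One way to bridge that gap, sketched here under assumptions rather than taken from the article, is to unpack the per-class values into individually named scalars, since a Keras log maps each metric name to a single number. The helper flatten_metrics and the precision_0-style key names are hypothetical, and plain Python lists stand in for the tensors.

```python
def flatten_metrics(precision, recall, f1):
    """Turn three per-class sequences into one dict of named scalars (hypothetical helper)."""
    logs = {}
    for name, values in (("precision", precision), ("recall", recall), ("f1", f1)):
        for class_index, value in enumerate(values):
            # One entry per metric and class, e.g. "recall_1".
            logs[f"{name}_{class_index}"] = float(value)
    return logs


# Illustrative per-class values for a two-class problem.
logs = flatten_metrics([0.5, 1.0], [1.0, 0.5], [0.667, 0.667])
```

In TensorFlow 2.2, a dict of this shape could plausibly be returned from an overridden Model.train_step, the hook that release introduced for customizing what fit does on each batch.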
