Operator overloading in Python
Operator overloading lets you redefine what an operator means when it is applied to objects of your own class.
It is thanks to operator overloading that the + operator can add two numbers and also concatenate two strings.
This ability of the same operator to take on different meanings depending on the context is what is called operator overloading.
So what happens when we use an operator with objects of a user-defined class? Consider the following class:

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


p1 = Point(2, 4)
p2 = Point(5, 1)
# Python does not know how to add two Point objects
p3 = p1 + p2
```

Running it produces:

```
Traceback (most recent call last):
  File "prog.py", line 10, in <module>
    p3 = p1 + p2
TypeError: unsupported operand type(s) for +: 'Point' and 'Point'
```
A TypeError is raised because Python does not know how to add two Point objects together.
Fortunately, we can teach it how to do so by overloading the operator.
Operator overloading relies on special methods. In Python, special methods (often called "magic" or "dunder" methods) are methods whose names start and end with a double underscore (__).
To overload an operator, we define the corresponding special method in the class definition.
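Built-in types already work this way: their operators are backed by special methods. The following minimal sketch, a standalone illustration rather than one of the article's examples, calls those methods directly and checks that a bare Point class like the one above defines no __add__ of its own.

```python
# Built-in types define the special methods behind their operators.
print((2).__add__(3))        # 5, same as 2 + 3
print("ab".__add__("cd"))    # abcd, same as "ab" + "cd"
print(hasattr(int, "__add__"), hasattr(str, "__add__"))  # True True


# A plain class inherits no __add__ from object, hence the TypeError.
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


print(hasattr(Point, "__add__"))  # False
```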
Arithmetic operators
+ Operator
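To overload the + operator, we need to implement the __add__() method in the class. With great power comes great responsibility: we can do whatever we want inside this method, so it is up to us to give the result a sensible meaning. Here, adding two points adds their coordinates.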
Example 1:

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        # Used by print() to display the point
        return "({0},{1})".format(self.x, self.y)

    def __add__(self, p):
        # Called for self + p: add the coordinates one by one
        a = self.x + p.x
        b = self.y + p.y
        return Point(a, b)


p1 = Point(2, 4)
p2 = Point(5, 1)
p3 = p1 + p2
print(p3)  # (7,5)
```
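When Python evaluates an expression of the form x + y, it calls x.__add__(y) (strictly speaking, type(x).__add__(x, y)). If that method does not exist or returns NotImplemented, Python then tries y.__radd__(x). The method that actually runs therefore depends on the types of both x and y.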
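To make that dispatch concrete, here is a minimal sketch of a Point class whose __add__ accepts either another Point or a number, returns NotImplemented for anything else, and defines __radd__ so that number + Point also works. Adding a number to both coordinates is an illustrative assumption, not part of the original examples.

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "Point({0}, {1})".format(self.x, self.y)

    def __add__(self, other):
        # Point + Point: add the coordinates
        if isinstance(other, Point):
            return Point(self.x + other.x, self.y + other.y)
        # Point + number: shift both coordinates (illustrative choice)
        if isinstance(other, (int, float)):
            return Point(self.x + other, self.y + other)
        # Unsupported operand: let Python try the other operand's __radd__
        return NotImplemented

    def __radd__(self, other):
        # Reached for number + Point, after int.__add__ returns NotImplemented
        return self.__add__(other)


p = Point(2, 4)
print(p + Point(5, 1))         # Point(7, 5)
print(p.__add__(Point(5, 1)))  # Point(7, 5), the call that p + Point(5, 1) makes
print(10 + p)                  # Point(12, 14), via p.__radd__(10)
```

With this design, p + "a" still raises a TypeError: __add__ returns NotImplemented and str has no __radd__, so Python has nothing left to try.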
Special methods for operator overloading in Python
The following table lists the operators and their corresponding special methods.
Operator | Function to overload | Expression | Python interpretation |
---|---|---|---|
Addition | __add__ | p1 + p2 | p1.__add__(p2) |
Subtraction | __sub__ | p1 - p2 | p1.__sub__(p2) |
Multiplication | __mul__ | p1 * p2 | p1.__mul__(p2) |
Power | __pow__ | p1 ** p2 | p1.__pow__(p2) |
Division | __truediv__ | p1 / p2 | p1.__truediv__(p2) |
Floor division | __floordiv__ | p1 // p2 | p1.__floordiv__(p2) |
Remainder (modulo) | __mod__ | p1 % p2 | p1.__mod__(p2) |
Bitwise left shift | __lshift__ | p1 << p2 | p1.__lshift__(p2) |
Bitwise right shift | __rshift__ | p1 >> p2 | p1.__rshift__(p2) |
Bitwise AND | __and__ | p1 & p2 | p1.__and__(p2) |
Bitwise OR | __or__ | p1 \| p2 | p1.__or__(p2) |
Bitwise XOR | __xor__ | p1 ^ p2 | p1.__xor__(p2) |
Bitwise NOT | __invert__ | ~p1 | p1.__invert__() |
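As a minimal sketch of a few entries from this table, the class below implements __sub__, __mul__ and __floordiv__. One assumption differs from the table: here the right operand of * and // is a plain number (scaling the point) rather than another Point. This is an illustrative design choice; the special method itself decides what types it accepts.

```python
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        return "({0},{1})".format(self.x, self.y)

    def __sub__(self, p):
        # p1 - p2 -> p1.__sub__(p2): subtract the coordinates
        return Point(self.x - p.x, self.y - p.y)

    def __mul__(self, k):
        # p1 * k -> p1.__mul__(k); k is assumed to be a number
        return Point(self.x * k, self.y * k)

    def __floordiv__(self, k):
        # p1 // k -> p1.__floordiv__(k); floor division of each coordinate
        return Point(self.x // k, self.y // k)


p1 = Point(2, 4)
p2 = Point(5, 1)
print(p1 - p2)   # (-3,3)
print(p1 * 3)    # (6,12)
print(p2 // 2)   # (2,0)
```

Note that 3 * p1 would still fail here, because only __mul__ is defined; supporting the reversed order would need __rmul__.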
Overloading comparison operators
Operator overloading is not limited to arithmetic operators: comparison operators can be overloaded too.
Suppose we want to implement the less-than operator (<) in our Point class.
We compare the distances of the two points from the origin and return the result. It can be implemented as follows:
Example 2:

```python
import math


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __lt__(self, p):
        # Distance of each point from the origin (0, 0)
        m_self = math.sqrt((self.x ** 2) + (self.y ** 2))
        m_p = math.sqrt((p.x ** 2) + (p.y ** 2))
        return m_self < m_p


p1 = Point(2, 4)
p2 = Point(5, 1)
if p1 < p2:
    print("p2 is farther from the origin than p1")
```
Likewise, the special methods used to overload the other comparison operators are summarized below.
Operator | Function to overload | Expression | Python interpretation |
---|---|---|---|
Less than | __lt__ | p1 < p2 | p1.__lt__(p2) |
Less than or equal | __le__ | p1 <= p2 | p1.__le__(p2) |
Equal | __eq__ | p1 == p2 | p1.__eq__(p2) |
Not equal | __ne__ | p1 != p2 | p1.__ne__(p2) |
Greater than | __gt__ | p1 > p2 | p1.__gt__(p2) |
Greater than or equal | __ge__ | p1 >= p2 | p1.__ge__(p2) |
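You do not always have to implement all six methods. When a comparison method is missing (or returns NotImplemented), Python tries the reflected method on the other operand, so p2 > p1 can fall back to p1.__lt__(p2); and in Python 3 the default __ne__ simply inverts the result of __eq__. The sketch below relies on that behavior. The magnitude() helper is a hypothetical addition, not part of the original class.

```python
import math


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def magnitude(self):
        # Hypothetical helper: distance from the origin
        return math.sqrt(self.x ** 2 + self.y ** 2)

    def __lt__(self, p):
        return self.magnitude() < p.magnitude()

    def __eq__(self, p):
        return self.magnitude() == p.magnitude()


p1 = Point(2, 4)
p2 = Point(5, 1)
print(p1 < p2)   # True: p1.__lt__(p2)
print(p2 > p1)   # True: no __gt__, so Python tries the reflection p1.__lt__(p2)
print(p1 != p2)  # True: the inherited __ne__ inverts __eq__
print(Point(3, 4) == Point(4, 3))  # True: both are at distance 5
```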
Example 3:
The following example combines several of these operators in a single class.

```python
import math


class Point:
    def __init__(self, x=0, y=0):
        # Double-underscore attributes are name-mangled to _Point__x and _Point__y
        self.__x = x
        self.__y = y

    def __str__(self):
        return "({0},{1})".format(self.__x, self.__y)

    # Operator +
    def __add__(self, p):
        return Point(self.__x + p.__x, self.__y + p.__y)

    # Operator -
    def __sub__(self, p):
        return Point(self.__x - p.__x, self.__y - p.__y)

    # Operator <
    def __lt__(self, p):
        m_self = math.sqrt((self.__x ** 2) + (self.__y ** 2))
        m_p = math.sqrt((p.__x ** 2) + (p.__y ** 2))
        return m_self < m_p

    # Operator <=
    def __le__(self, p):
        m_self = math.sqrt((self.__x ** 2) + (self.__y ** 2))
        m_p = math.sqrt((p.__x ** 2) + (p.__y ** 2))
        return m_self <= m_p

    # Operator >
    def __gt__(self, p):
        m_self = math.sqrt((self.__x ** 2) + (self.__y ** 2))
        m_p = math.sqrt((p.__x ** 2) + (p.__y ** 2))
        return m_self > m_p

    # Operator >=
    def __ge__(self, p):
        m_self = math.sqrt((self.__x ** 2) + (self.__y ** 2))
        m_p = math.sqrt((p.__x ** 2) + (p.__y ** 2))
        return m_self >= m_p

    # Operator ==
    def __eq__(self, p):
        m_self = math.sqrt((self.__x ** 2) + (self.__y ** 2))
        m_p = math.sqrt((p.__x ** 2) + (p.__y ** 2))
        return m_self == m_p


p1 = Point(2, 4)
p2 = Point(5, 1)
print(p1 + p2)  # (7,5)
print(p1 - p2)  # (-3,3)
print(p1 < p2, p1 <= p2, p1 > p2, p1 >= p2, p1 == p2)  # True True False False False
```
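The class above repeats the distance computation in every comparison method. One way to reduce that repetition, sketched here as an alternative rather than the original author's code, is the standard-library decorator functools.total_ordering: given __eq__ and one of __lt__, __le__, __gt__ or __ge__, it supplies the remaining ordering methods. The private __magnitude() helper is a hypothetical addition. Returning NotImplemented for non-Point operands lets Python apply its default handling for unsupported operands instead of failing with an AttributeError.

```python
import functools
import math


@functools.total_ordering
class Point:
    def __init__(self, x=0, y=0):
        self.__x = x
        self.__y = y

    def __magnitude(self):
        # Hypothetical helper: distance from the origin, shared by all comparisons
        return math.sqrt(self.__x ** 2 + self.__y ** 2)

    def __eq__(self, p):
        if not isinstance(p, Point):
            return NotImplemented
        return self.__magnitude() == p.__magnitude()

    def __lt__(self, p):
        if not isinstance(p, Point):
            return NotImplemented
        return self.__magnitude() < p.__magnitude()
    # total_ordering derives __le__, __gt__ and __ge__ from __lt__ and __eq__


p1 = Point(2, 4)
p2 = Point(5, 1)
print(p1 <= p2, p1 > p2, p1 >= p2)  # True False False
```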