In this article we will discuss how to make a custom class Iterable and how to create an Iterator class for it.
Why should we make a custom class Iterable?
Classes that we define ourselves are not Iterable by default. If we want to iterate over the objects of our custom class, we need to make the class Iterable and also create an Iterator class for it.
Let’s understand this with an example.
Suppose we have a class Team that contains lists of junior and senior team members, i.e.
```python
class Team:
    '''
    Contains List of Junior and senior team members
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members
```
Now let’s create an object of this class and add some junior and senior team members to it, i.e.
```python
# Create team class object
team = Team()
# Add name of junior team members
team.addJuniorMembers(['Sam', 'John', 'Marshal'])
# Add name of senior team members
team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])
```
So far this class is not Iterable. Therefore, if we call the iter() function on an object of this class, i.e.

```python
iter(team)
```

or try to iterate over this class’s object using a for loop, i.e.
```python
for member in team:
    print(member)
```

then it will raise the following error:

```
TypeError: 'Team' object is not iterable
```

So, to iterate over the elements of a Team object, we need to make the class Team Iterable.
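Before changing the class, it can help to check whether an object is iterable at all. The sketch below uses a hypothetical `Plain` class (a stand-in for the first version of Team, not part of the article) and a hypothetical helper `is_iterable()`. It compares the standard `collections.abc.Iterable` check, which only detects classes that define `__iter__()`, with simply calling `iter()`, which is the more reliable test.

```python
from collections.abc import Iterable


class Plain:
    '''Hypothetical class with no __iter__() method, like the first Team class.'''
    def __init__(self):
        self._items = ['Sam', 'John']


def is_iterable(obj):
    '''Hypothetical helper: True if iter() accepts obj, False otherwise.'''
    try:
        iter(obj)
    except TypeError:
        return False
    return True


plain = Plain()
print(isinstance(plain, Iterable))   # False: Plain has no __iter__()
print(is_iterable(plain))            # False: iter() raises TypeError
print(is_iterable(['Sam', 'John']))  # True: lists are iterable
```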
How to make your Custom Class Iterable | The Iterator Protocol
To make a class Iterable, we need to define the __iter__() method inside the class, i.e.

```python
def __iter__(self):
    pass
```

This method should return an object of the Iterator class associated with this Iterable class.
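Python enforces this contract: iter() checks that the object returned by __iter__() is an iterator (that is, it has a __next__() method) and raises TypeError otherwise. In the sketch below, the hypothetical class `Broken` (not part of the article) wrongly returns a list from __iter__(), while the hypothetical `Fixed` returns a real iterator by delegating to the list’s built-in iterator.

```python
class Broken:
    '''Hypothetical class whose __iter__() returns a list instead of an iterator.'''
    def __init__(self):
        self._members = ['Sam', 'John']

    def __iter__(self):
        return self._members  # a list is iterable, but it is not an iterator


class Fixed(Broken):
    '''Same data, but __iter__() returns an iterator over the list.'''
    def __iter__(self):
        return iter(self._members)


try:
    iter(Broken())
    broken_rejected = False
except TypeError as exc:
    # iter() refuses objects whose __iter__() does not return an iterator
    broken_rejected = True
    print('TypeError:', exc)

print(list(Fixed()))  # ['Sam', 'John']
```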
So, our Iterable Team class will look like this:
```python
class Team:
    '''
    Contains List of Junior and senior team members and also implements the __iter__() method.
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        ''' Returns the Iterator object '''
        return TeamIterator(self)
```
It defines the __iter__() method, which returns an object of the Iterator class, i.e. TeamIterator in our case.
If we call the iter() function on an object of class Team, it in turn calls the __iter__() method of that object, which returns an object of the Iterator class TeamIterator, i.e.

```python
# Get Iterator object from Iterable Team class object
iterator = iter(team)
print(iterator)
```
Output (the memory address will differ from run to run):

```
<__main__.TeamIterator object at 0x01C052D0>
```
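To see that iter() simply delegates to __iter__(), here is a minimal sketch with hypothetical `Box` and `BoxIterator` classes (not part of the article) in which `Box` counts how often its __iter__() is called. Python looks up __iter__() on the type, so `iter(box)` behaves like `type(box).__iter__(box)`, and a for loop makes the same call once at its start.

```python
class BoxIterator:
    '''Hypothetical iterator over the items of a Box.'''
    def __init__(self, box):
        self._box = box
        self._index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= len(self._box.items):
            raise StopIteration
        item = self._box.items[self._index]
        self._index += 1
        return item


class Box:
    '''Hypothetical iterable that counts how often __iter__() is called.'''
    def __init__(self, items):
        self.items = list(items)
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return BoxIterator(self)


box = Box(['a', 'b'])
it1 = iter(box)                # iter() calls Box.__iter__(box)
it2 = type(box).__iter__(box)  # the same call, made explicitly
print(type(it1).__name__, box.iter_calls)  # BoxIterator 2

for _ in box:                  # a for loop also calls iter(box) once
    pass
print(box.iter_calls)          # 3
```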
Now let’s see how to create an Iterator class that can iterate over the contents of this Iterable class Team.
How to create an Iterator class
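To create an Iterator class, we need to define the __next__() method inside the class. Strictly, Python’s iterator protocol also asks an iterator to define an __iter__() method that returns the iterator itself, so that the iterator can be used anywhere an iterable is expected, i.e.

```python
def __next__(self):
    pass

def __iter__(self):
    return self
```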
The __next__() method should be implemented so that every time it is called, it returns the next element of the associated Iterable object. When there are no more elements, it should raise StopIteration.
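As a minimal illustration of this contract, here is a hypothetical `Countdown` iterator (not from the article) whose __next__() returns the numbers from n down to 1 and then raises StopIteration. The built-in next() also accepts an optional default, which it returns instead of propagating StopIteration.

```python
class Countdown:
    '''Hypothetical iterator yielding n, n-1, ..., 1.'''
    def __init__(self, n):
        self._current = n

    def __iter__(self):
        # An iterator returns itself from __iter__()
        return self

    def __next__(self):
        if self._current <= 0:
            # No more elements: signal the end of iteration
            raise StopIteration
        value = self._current
        self._current -= 1
        return value


it = Countdown(3)
print(next(it), next(it), next(it))  # 3 2 1
print(next(it, 'done'))              # done (default returned instead of StopIteration)
```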
Also, the Iterator object should be associated with the Iterable object so that it can access the Iterable’s data members.
Usually, inside the __iter__() method, the Iterable class passes a reference to its current object (self) to the Iterator’s constructor. Using this reference, the Iterator object can access the data members of the Iterable object.
Let’s create the TeamIterator class for the Iterable Team class, i.e.
```python
class TeamIterator:
    '''
    Iterator class
    '''
    def __init__(self, team):
        # Team object reference
        self._team = team
        # member variable to keep track of current index
        self._index = 0

    def __iter__(self):
        # An iterator returns itself, so it can also be used in a for loop
        return self

    def __next__(self):
        '''
        Returns the next value from team object's lists
        '''
        if self._index < (len(self._team._juniorMembers) + len(self._team._seniorMembers)):
            if self._index < len(self._team._juniorMembers):  # Check if junior members are fully iterated or not
                result = (self._team._juniorMembers[self._index], 'junior')
            else:
                result = (self._team._seniorMembers[self._index - len(self._team._juniorMembers)], 'senior')
            self._index += 1
            return result
        # End of Iteration
        raise StopIteration
```
It accepts a Team object in its constructor, and its __next__() method returns the next element from the Team object’s data members, i.e. _juniorMembers and _seniorMembers, in sequence.
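The standard `collections.abc.Iterator` ABC gives a quick way to check whether an iterator class follows the full protocol, since it requires both __iter__() and __next__(). In this sketch, the hypothetical `NextOnly` class (not part of the article) defines only __next__(), so it cannot be passed directly to a for loop, while the hypothetical `Complete` adds __iter__() returning itself.

```python
from collections.abc import Iterator


class NextOnly:
    '''Hypothetical iterator that defines only __next__().'''
    def __next__(self):
        raise StopIteration


class Complete(NextOnly):
    '''Same iterator, plus __iter__() returning itself.'''
    def __iter__(self):
        return self


print(isinstance(NextOnly(), Iterator))  # False
print(isinstance(Complete(), Iterator))  # True

try:
    # A for loop calls iter() on NextOnly(), which has no __iter__()
    for _ in NextOnly():
        pass
    next_only_loopable = True
except TypeError:
    next_only_loopable = False
print(next_only_loopable)  # False
```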
Now we can iterate over the contents of a Team object using its Iterator, i.e.
```python
# Create team class object
team = Team()
# Add name of junior team members
team.addJuniorMembers(['Sam', 'John', 'Marshal'])
# Add name of senior team members
team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])

# Get Iterator object from Iterable Team class object
iterator = iter(team)
# Iterate over the team object using iterator
while True:
    try:
        # Get next element from TeamIterator object using iterator object
        elem = next(iterator)
        # Print the element
        print(elem)
    except StopIteration:
        break
```
Output:

```
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
```
How did it work?
The iter() function calls __iter__() on the team object, which returns a TeamIterator object. Each call to the next() function on the TeamIterator object internally calls its __next__() method, which returns the details of the next member. The _index variable keeps track of the elements already iterated, so every call returns the next element, and once all elements have been returned, __next__() raises StopIteration.
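Because the position is stored in the iterator’s _index rather than in the Team object itself, each call to iter() starts a fresh pass, while an exhausted iterator simply keeps raising StopIteration. Here is a minimal sketch with hypothetical `Roster` and `RosterIterator` classes (not part of the article) that mirror the Team/TeamIterator design.

```python
class RosterIterator:
    '''Hypothetical iterator keeping its own position, like TeamIterator.'''
    def __init__(self, roster):
        self._roster = roster
        self._index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= len(self._roster.members):
            raise StopIteration
        member = self._roster.members[self._index]
        self._index += 1
        return member


class Roster:
    '''Hypothetical iterable that hands out a new iterator on every iter() call.'''
    def __init__(self, members):
        self.members = list(members)

    def __iter__(self):
        return RosterIterator(self)


roster = Roster(['Sam', 'John'])
first = iter(roster)
print(list(first))         # ['Sam', 'John']
print(list(first))         # [] because this iterator is already exhausted
print(list(iter(roster)))  # ['Sam', 'John'] from a fresh iterator
```

Since every loop gets its own iterator, nested loops over the same Roster object also work independently.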
Now, as our Team class is Iterable, we can also iterate over its contents using a for loop, i.e.

```python
# Iterate over team object(Iterable)
for member in team:
    print(member)
```
Output:

```
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
```
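Besides for loops, other consumers of iterables in Python, such as list(), sum(), tuple unpacking, and the `in` operator (when the class defines no __contains__()), also go through __iter__(). A sketch with a hypothetical `Scores` class (not part of the article) whose __iter__() delegates to the built-in list iterator:

```python
class Scores:
    '''Hypothetical iterable; __iter__() delegates to the list's own iterator.'''
    def __init__(self, values):
        self._values = list(values)

    def __iter__(self):
        return iter(self._values)


scores = Scores([3, 5, 7])
print(list(scores))      # [3, 5, 7]
print(sum(scores))       # 15
print(5 in scores)       # True, found by iterating since there is no __contains__()
low, mid, high = scores  # unpacking also calls iter()
print(low, mid, high)    # 3 5 7
```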
The complete example is as follows:
```python
class TeamIterator:
    '''
    Iterator class
    '''
    def __init__(self, team):
        # Team object reference
        self._team = team
        # member variable to keep track of current index
        self._index = 0

    def __iter__(self):
        # An iterator returns itself, so it can also be used in a for loop
        return self

    def __next__(self):
        '''
        Returns the next value from team object's lists
        '''
        if self._index < (len(self._team._juniorMembers) + len(self._team._seniorMembers)):
            if self._index < len(self._team._juniorMembers):  # Check if junior members are fully iterated or not
                result = (self._team._juniorMembers[self._index], 'junior')
            else:
                result = (self._team._seniorMembers[self._index - len(self._team._juniorMembers)], 'senior')
            self._index += 1
            return result
        # End of Iteration
        raise StopIteration


class Team:
    '''
    Contains List of Junior and senior team members and also implements the __iter__() method.
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        ''' Returns the Iterator object '''
        return TeamIterator(self)


def main():
    # Create team class object
    team = Team()
    # Add name of junior team members
    team.addJuniorMembers(['Sam', 'John', 'Marshal'])
    # Add name of senior team members
    team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])

    print('*** Iterate over the team object using for loop ***')
    # Iterate over team object(Iterable)
    for member in team:
        print(member)

    print('*** Iterate over the team object using while loop ***')
    # Get Iterator object from Iterable Team class object
    iterator = iter(team)
    # Iterate over the team object using iterator
    while True:
        try:
            # Get next element from TeamIterator object using iterator object
            elem = next(iterator)
            # Print the element
            print(elem)
        except StopIteration:
            break


if __name__ == '__main__':
    main()
```
Output:

```
*** Iterate over the team object using for loop ***
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
*** Iterate over the team object using while loop ***
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
```
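A common alternative, not the approach used in this article, is to write __iter__() as a generator method using yield. Python then builds the iterator object, including __next__() and the final StopIteration, automatically, so no separate Iterator class is needed. A sketch assuming the same data layout as Team, under the hypothetical class name `GeneratorTeam`:

```python
class GeneratorTeam:
    '''Hypothetical variant of Team whose __iter__() is a generator.'''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        # Each call returns a new generator object, which acts as the iterator
        for member in self._juniorMembers:
            yield (member, 'junior')
        for member in self._seniorMembers:
            yield (member, 'senior')


team = GeneratorTeam()
team.addJuniorMembers(['Sam', 'John', 'Marshal'])
team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])
for member in team:
    print(member)
```

This prints the same six tuples as the TeamIterator version; the trade-off is less explicit control over iterator state in exchange for much shorter code.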