In this article we will discuss how to make your custom class Iterable and also create an Iterator class for it.

Why should we make a Custom class Iterable?

Custom classes that we create are not Iterable by default. If we want to iterate over the objects of our custom class, then we need to make the class Iterable and also create an Iterator class for it.
Let’s understand this with an example.

Suppose we have a class Team that contains lists of junior and senior team members:

class Team:
    '''
    Contains lists of junior and senior team members
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

Now let’s create an object of this class and add some junior and senior team members to it:
# Create team class object
team = Team()

# Add name of junior team members
team.addJuniorMembers(['Sam', 'John', 'Marshal'])

# Add name of senior team members
team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])

At this point the class is not Iterable, so if we call the iter() function on an object of this class:
iter(team)

or try to iterate over an object of this class using a for loop:
for member in team:
    print(member)

then Python raises the following error:
TypeError: 'Team' object is not iterable

So, to iterate over the members stored in a Team object, we need to make the Team class Iterable.
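The same failure can be detected from code instead of from a traceback. The sketch below repeats the Team class exactly as defined above (still without __iter__()) so that it runs on its own; the check with collections.abc.Iterable is an addition for illustration and is not part of the original article.

```python
from collections.abc import Iterable

class Team:
    '''Team class from above, still without __iter__().'''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

team = Team()
team.addJuniorMembers(['Sam', 'John', 'Marshal'])

# False: Team does not define an __iter__() method
print(isinstance(team, Iterable))

try:
    iter(team)
except TypeError as err:
    # Prints: 'Team' object is not iterable
    print(err)
```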

How to make your Custom Class Iterable | The Iterator Protocol

To make our class Iterable, we need to implement the __iter__() function inside the class:

def __iter__(self):
    pass

This function should return an object of the Iterator class associated with this Iterable class.

So, our Iterable Team class will look like this:

class Team:
    '''
    Contains lists of junior and senior team members and also implements the __iter__() function.
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        ''' Returns the Iterator object '''
        return TeamIterator(self)


It implements the __iter__() function, which returns an object of the Iterator class, i.e. TeamIterator in our case.

If we call the iter() function on an object of class Team, it in turn calls the __iter__() function of that object, which returns an object of the Iterator class TeamIterator:

# Get Iterator object from Iterable Team class object
iterator = iter(team)

print(iterator)

Output:
<__main__.TeamIterator object at 0x01C052D0>
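To see that iter() really delegates to __iter__(), here is a small sketch using a hypothetical Deck class (not from the article) that counts how often its __iter__() runs. For brevity its __iter__() returns the list's built-in iterator instead of a custom class like TeamIterator; the delegation works the same way.

```python
class Deck:
    '''Hypothetical Iterable used to show that iter() calls __iter__().'''
    def __init__(self, cards):
        self.cards = list(cards)
        self.iter_calls = 0   # counts how many times __iter__() has run

    def __iter__(self):
        self.iter_calls += 1
        # For brevity this returns the list's built-in iterator rather than
        # a custom Iterator class like TeamIterator
        return iter(self.cards)

deck = Deck(['A', 'K', 'Q'])

iterator = iter(deck)       # iter() calls Deck.__iter__() once
print(deck.iter_calls)      # 1

for card in deck:           # a for loop calls __iter__() again
    pass
print(deck.iter_calls)      # 2
print(next(iterator))       # A
```

Every iter() call, including the implicit one made by a for loop, runs __iter__() again and gets a new iterator.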

Now let’s see how to create an Iterator class that can iterate over the contents of this Iterable class Team.

How to create an Iterator class

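To create an Iterator class we need to implement the __next__() function inside the class. The iterator protocol also expects an iterator to implement __iter__() and return itself, so that the iterator object is itself Iterable and can be used wherever an Iterable is expected (for example, directly in a for loop):

def __next__(self):
    pass

def __iter__(self):
    return self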

The __next__() function should be implemented so that every call returns the next element of the associated Iterable object. When there are no more elements, it should raise StopIteration.
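A minimal sketch of this contract, using a hypothetical CountdownIterator that is not part of the article: __next__() returns one value per call and raises StopIteration at the end, while __iter__() returns self. The built-in next() also accepts an optional default that is returned instead of raising StopIteration.

```python
class CountdownIterator:
    '''Hypothetical iterator: yields start, start-1, ..., 1, then stops.'''
    def __init__(self, start):
        self._current = start

    def __iter__(self):
        # An iterator returns itself, so it can be used in a for loop directly
        return self

    def __next__(self):
        if self._current <= 0:
            # No more elements: signal the end of iteration
            raise StopIteration
        value = self._current
        self._current -= 1
        return value

it = CountdownIterator(2)
print(next(it))          # 2
print(next(it))          # 1
print(next(it, 'done'))  # done (default returned instead of StopIteration)

# Because __iter__() returns self, the iterator also works with list()
print(list(CountdownIterator(3)))   # [3, 2, 1]
```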

Also, the Iterator object needs access to the data members of the Iterable object. Usually, inside the __iter__() function, the Iterable class passes a reference to its current object (self) to the Iterator’s constructor, and the Iterator uses that reference to access the Iterable’s data members.

Let’s create the TeamIterator class for the Iterable Team class:

class TeamIterator:
    ''' Iterator class '''
    def __init__(self, team):
        # Team object reference
        self._team = team
        # member variable to keep track of current index
        self._index = 0

    def __iter__(self):
        ''' An iterator returns itself, so it is also Iterable '''
        return self

    def __next__(self):
        ''' Returns the next value from team object's lists '''
        if self._index < (len(self._team._juniorMembers) + len(self._team._seniorMembers)):
            # Check if junior members are fully iterated or not
            if self._index < len(self._team._juniorMembers):
                result = (self._team._juniorMembers[self._index], 'junior')
            else:
                result = (self._team._seniorMembers[self._index - len(self._team._juniorMembers)], 'senior')
            self._index += 1
            return result
        # End of Iteration
        raise StopIteration


It accepts a Team object in its constructor, and its __next__() function returns the next element from the Team object’s data members, i.e. _juniorMembers followed by _seniorMembers.
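The only tricky part of __next__() is mapping one running _index onto two lists. A worked sketch of that arithmetic, using a hypothetical helper member_at() that mirrors the branch in TeamIterator.__next__(): indexes below len(_juniorMembers) select a junior member, and larger indexes are shifted back by len(_juniorMembers) to select a senior member.

```python
def member_at(index, juniors, seniors):
    '''Hypothetical helper mirroring the index arithmetic in TeamIterator.__next__().'''
    if index < len(juniors):
        return (juniors[index], 'junior')
    # Shift the index so the first senior member is at position 0
    return (seniors[index - len(juniors)], 'senior')

juniors = ['Sam', 'John', 'Marshal']
seniors = ['Riti', 'Rani', 'Aadi']

print(member_at(2, juniors, seniors))   # ('Marshal', 'junior')
print(member_at(3, juniors, seniors))   # ('Riti', 'senior'), i.e. seniors[3 - 3]
print(member_at(5, juniors, seniors))   # ('Aadi', 'senior'), i.e. seniors[5 - 3]
```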

Now we can iterate over the contents of a Team object using the Iterator:

# Create team class object
team = Team()
# Add name of junior team members
team.addJuniorMembers(['Sam', 'John', 'Marshal'])
# Add name of senior team members
team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])

# Get Iterator object from Iterable Team class object
iterator = iter(team)

# Iterate over the team object using iterator
while True:
    try:
        # Get next element from TeamIterator object using iterator object
        elem = next(iterator)
        # Print the element
        print(elem)
    except StopIteration:
        break

Output:
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')

How did it work?

The iter() function calls __iter__() on the team object, which returns a TeamIterator object. Each call to the next() function on the TeamIterator object internally calls its __next__() function, which returns the next member’s details. It uses the _index variable to keep track of the elements already iterated, so every call returns the next element, and once all elements have been returned it raises StopIteration.
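One consequence of keeping _index inside the iterator is that an iterator is used up after one pass. The sketch below includes a condensed copy of the Team and TeamIterator classes (same behaviour, with the branches folded into if/elif/else and __iter__() added to the iterator) so that it runs on its own. Once StopIteration has been reached, further next() calls keep failing, and a fresh iter(team) starts again from index 0.

```python
class TeamIterator:
    '''Condensed copy of the article's TeamIterator.'''
    def __init__(self, team):
        self._team = team
        self._index = 0

    def __iter__(self):
        return self

    def __next__(self):
        juniors = self._team._juniorMembers
        seniors = self._team._seniorMembers
        if self._index < len(juniors):
            result = (juniors[self._index], 'junior')
        elif self._index < len(juniors) + len(seniors):
            result = (seniors[self._index - len(juniors)], 'senior')
        else:
            raise StopIteration
        self._index += 1
        return result

class Team:
    '''Condensed copy of the article's Iterable Team class.'''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        return TeamIterator(self)

team = Team()
team.addJuniorMembers(['Sam'])
team.addSeniorMembers(['Riti'])

iterator = iter(team)
print(next(iterator))          # ('Sam', 'junior')
print(next(iterator))          # ('Riti', 'senior')
print(next(iterator, None))    # None: _index has reached the end
print(next(iterator, None))    # None again: an exhausted iterator stays exhausted

# A fresh iterator from iter(team) starts again at _index 0
print(next(iter(team)))        # ('Sam', 'junior')
```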

Since our Team class is now Iterable, we can also iterate over the contents of a Team object using a for loop:

# Iterate over team object(Iterable)
for member in team:
    print(member)

Output:
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
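A for loop is not the only consumer: anything that calls iter() on its argument now accepts a Team object. The sketch below repeats a condensed copy of the two classes so it runs on its own and shows list(), the in operator (which falls back to iteration because Team defines no __contains__()), a list comprehension and enumerate().

```python
class TeamIterator:
    '''Condensed copy of the article's TeamIterator.'''
    def __init__(self, team):
        self._team = team
        self._index = 0

    def __iter__(self):
        return self

    def __next__(self):
        juniors = self._team._juniorMembers
        seniors = self._team._seniorMembers
        if self._index < len(juniors):
            result = (juniors[self._index], 'junior')
        elif self._index < len(juniors) + len(seniors):
            result = (seniors[self._index - len(juniors)], 'senior')
        else:
            raise StopIteration
        self._index += 1
        return result

class Team:
    '''Condensed copy of the article's Iterable Team class.'''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        return TeamIterator(self)

team = Team()
team.addJuniorMembers(['Sam', 'John'])
team.addSeniorMembers(['Riti'])

# Each of these calls iter(team) internally, just like a for loop
print(list(team))                    # [('Sam', 'junior'), ('John', 'junior'), ('Riti', 'senior')]
print(('Riti', 'senior') in team)    # True
names = [name for name, level in team if level == 'junior']
print(names)                         # ['Sam', 'John']
for position, member in enumerate(team, start=1):
    print(position, member)
```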

The complete example is as follows:
class TeamIterator:
    ''' Iterator class '''
    def __init__(self, team):
        # Team object reference
        self._team = team
        # member variable to keep track of current index
        self._index = 0

    def __iter__(self):
        ''' An iterator returns itself, so it is also Iterable '''
        return self

    def __next__(self):
        ''' Returns the next value from team object's lists '''
        if self._index < (len(self._team._juniorMembers) + len(self._team._seniorMembers)):
            # Check if junior members are fully iterated or not
            if self._index < len(self._team._juniorMembers):
                result = (self._team._juniorMembers[self._index], 'junior')
            else:
                result = (self._team._seniorMembers[self._index - len(self._team._juniorMembers)], 'senior')
            self._index += 1
            return result
        # End of Iteration
        raise StopIteration



class Team:
    '''
    Contains lists of junior and senior team members and also implements the __iter__() function.
    '''
    def __init__(self):
        self._juniorMembers = list()
        self._seniorMembers = list()

    def addJuniorMembers(self, members):
        self._juniorMembers += members

    def addSeniorMembers(self, members):
        self._seniorMembers += members

    def __iter__(self):
        ''' Returns the Iterator object '''
        return TeamIterator(self)



def main():
    # Create team class object
    team = Team()
    # Add name of junior team members
    team.addJuniorMembers(['Sam', 'John', 'Marshal'])
    # Add name of senior team members
    team.addSeniorMembers(['Riti', 'Rani', 'Aadi'])

    print('*** Iterate over the team object using for loop ***')

    # Iterate over team object(Iterable)
    for member in team:
        print(member)

    print('*** Iterate over the team object using while loop ***')

    # Get Iterator object from Iterable Team class object
    iterator = iter(team)

    # Iterate over the team object using iterator
    while True:
        try:
            # Get next element from TeamIterator object using iterator object
            elem = next(iterator)
            # Print the element
            print(elem)
        except StopIteration:
            break

if __name__ == '__main__':
    main()

Output:
*** Iterate over the team object using for loop ***
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
*** Iterate over the team object using while loop ***
('Sam', 'junior')
('John', 'junior')
('Marshal', 'junior')
('Riti', 'senior')
('Rani', 'senior')
('Aadi', 'senior')
