import boto3
import hashlib
import json
import logging
import urllib.request, urllib.error, urllib.parse
import os
import random
import string

# Deployment configuration, read from the environment once at import time.
# Comma-separated TCP ports to open for every allowed CIDR (defaults to "80").
INGRESS_PORTS = os.environ.get("PORTS", "80").split(",")
# Target VPC id; an empty string means "resolve the default VPC" (see update_security_groups).
VPC_ID = os.environ.get("VPC_ID", "")
# Region used for the EC2 client.
REGION = os.environ.get("REGION", "us-east-1")
# NOTE(review): NRULES, NAME and SERVICE are referenced below but not defined in
# this file as shown — confirm where they are supposed to come from.

def lambda_handler(event, context):
    """Lambda entry point: refresh the managed security groups from a new ranges document.

    Triggered by the SNS notification sent when the IP ranges document is
    rotated. The first record's message is JSON carrying the document ``url``
    and its expected ``md5``.

    Args:
        event: SNS-triggered Lambda event.
        context: Lambda context object (unused).
    """
    global NRANGES

    # Set up logging: reuse a handler the runtime already installed, otherwise
    # configure one ourselves.
    # NOTE(review): level choices reconstructed from a damaged original — confirm.
    root_logger = logging.getLogger()
    if len(root_logger.handlers) > 0:
        root_logger.setLevel(logging.ERROR)
    else:
        logging.basicConfig(level=logging.DEBUG)

    # Set the environment variable DEBUG to 'true' if you want verbose debug details in CloudWatch Logs.
    # .get() replaces the original try/except KeyError around os.environ['DEBUG'].
    if os.environ.get('DEBUG') == 'true':
        root_logger.setLevel(logging.INFO)

    # SNS message notification event when the ip ranges document is rotated
    message = json.loads(event['Records'][0]['Sns']['Message'])

    ip_ranges = json.loads(get_ip_groups_json(message['url'], message['md5']))
    cf_ranges = get_ranges_for_service(ip_ranges, SERVICE)

    # Number of security group rules required as per the total range count
    NRANGES = len(cf_ranges) * len(INGRESS_PORTS)

    # Update SGs with the new ranges
    update_security_groups(cf_ranges)

def update_security_groups(new_ranges):
    """Synchronise the managed security groups in VPC_ID with ``new_ranges``.

    Resolves the account's default VPC when VPC_ID is empty, makes sure enough
    managed groups exist, then revokes stale CIDRs and authorizes new ones.

    Args:
        new_ranges: list of CIDR strings that should be allowed.

    Raises:
        RuntimeError: VPC_ID is empty and the region has no default VPC.
    """
    global VPC_ID

    # Creating ec2 boto3 client
    client = boto3.client('ec2', region_name=REGION)

    if VPC_ID == "":
        result = client.describe_vpcs(Filters=[{'Name': 'isDefault', 'Values': ['true']}])
        vpcs = result["Vpcs"]
        if not vpcs:
            raise RuntimeError('No default VPC found in region ' + REGION + '; set VPC_ID explicitly')
        VPC_ID = vpcs[0]['VpcId']

    # Managed groups (created on demand) that should hold the ranges.
    range_to_update = get_security_groups_for_update(client, True)
    # The response is a dict, so check its group list rather than the dict's length.
    if not range_to_update.get('SecurityGroups'):
        logging.warning('No groups to {}'.format("update"))
        return
    update_security_group(client, range_to_update, new_ranges)

def update_security_group(client, range_to_update, new_ranges):
    """Revoke stale CIDRs from the managed groups and authorize the missing ones.

    Args:
        client: boto3 EC2 client.
        range_to_update: ``describe_security_groups`` response for the managed groups.
        new_ranges: list of CIDR strings that should be allowed.
    """
    wanted = set(new_ranges)   # O(1) membership instead of list.count per rule
    old_prefixes = set()       # every CIDR currently present in any managed group
    to_revoke = {}             # group id -> set of CIDRs that are no longer wanted
    final_add = {}             # group id -> rule slots free once stale rules are revoked
    total = 0

    for each_grp in range_to_update['SecurityGroups']:
        group_id = each_grp['GroupId']
        to_revoke[group_id] = set()
        permissions = each_grp['IpPermissions']

        if not permissions:
            # Empty group: every slot is available.
            final_add[group_id] = NRULES
            total += NRULES
            continue

        # If there are any existing ranges in the SG, compare and add it to the revoke list if necessary
        to_revoke_sg = 0
        for permission in permissions:
            for ip_range in permission['IpRanges']:
                cidr = ip_range['CidrIp']
                old_prefixes.add(cidr)
                if cidr not in wanted:
                    to_revoke[group_id].add(cidr)
                    to_revoke_sg += 1

        # Available slots once the stale rules are revoked. Each CIDR uses one
        # rule per port; assumes every port permission carries the same CIDR list.
        remain_rules = NRULES - (
                len(permissions[0]['IpRanges']) * len(INGRESS_PORTS)) + to_revoke_sg
        logging.info("Total number of rules available in " + group_id + " are " + str(remain_rules))
        final_add[group_id] = remain_rules
        total += remain_rules

    # Compares and identifies the new range to add from the service ranges list.
    # dict.fromkeys keeps the input order while dropping duplicate CIDRs, which
    # EC2 would reject within a single authorize call.
    to_add = []
    for new_range in dict.fromkeys(new_ranges):
        if new_range not in old_prefixes:
            to_add.append({'CidrIp': new_range})
            logging.info(" Range to be added: " + new_range)

    count = 0
    for group, cidrs in to_revoke.items():
        if cidrs:
            count += len(cidrs)
            logging.info("Rules that have to be revoked for  " + str(cidrs))
            revoke_permissions(client, group, cidrs)
        else:
            logging.info("No rules were identified to be revoked in the security group " + group)

    logging.info("Total number of rules to be revoked in all the security groups are " + str(count * len(INGRESS_PORTS)))
    logging.info("Total number of rules to be added " + str(len(to_add) * len(INGRESS_PORTS)))
    logging.info("Rules to add " + str(to_add))
    dynamic_rule_add(client, final_add, to_add, total)

def dynamic_rule_add(client, final_add, to_add, total):
    """Authorize ``to_add`` across the managed groups, creating overflow groups if needed.

    Args:
        client: boto3 EC2 client.
        final_add: mapping of group id -> free rule slots in that group. Any
            overflow group created here is inserted into it (mutated in place).
        to_add: list of ``{'CidrIp': cidr}`` dicts still to be authorized.
        total: sum of the free slots in ``final_add``.

    Raises:
        ValueError: more slots are needed but NRULES is not positive.
    """
    ports = len(INGRESS_PORTS)
    required = len(to_add) * ports

    if total < required:
        if NRULES <= 0:
            raise ValueError('NRULES must be positive to create overflow security groups')
        existing_sgs = list(final_add.keys())
        new_sgs = []
        # Each new group contributes NRULES slots; keep creating until everything fits.
        while total < required:
            random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=3))
            security_group = client.create_security_group(
                Description=NAME + "-" + random_str,
                GroupName=NAME + "-" + random_str,
                VpcId=VPC_ID,
            )
            group_id = security_group['GroupId']
            # Tag like create_security_groups does, so later runs find (and clean up) this group.
            client.create_tags(Resources=[group_id], Tags=[{'Key': 'PREFIX_NAME', 'Value': NAME}])
            final_add[group_id] = NRULES
            total += NRULES
            new_sgs.append(group_id)

        # NOTE(review): reconstructed from a damaged original. Attach the overflow
        # groups to every ENI that already uses a managed group, keeping the
        # ENI's other groups. Only the first page of results is processed.
        if existing_sgs:
            response = client.describe_network_interfaces(
                Filters=[{'Name': 'group-id', 'Values': existing_sgs}]
            )
            for each_eni in response['NetworkInterfaces']:
                groups = [grp['GroupId'] for grp in each_eni['Groups']]
                groups += [sg for sg in new_sgs if sg not in groups]
                client.modify_network_interface_attribute(
                    NetworkInterfaceId=each_eni['NetworkInterfaceId'],
                    Groups=groups,
                )

    for each_grp, capacity in final_add.items():
        if not to_add:
            break
        # Each CIDR takes one rule per port. Clamp at 0 so a negative capacity
        # cannot turn the slice below into "all but the last N".
        num_accommodate = max(capacity, 0) // ports
        logging.info("Number of rules can security group " + each_grp + " accommodate: " + str(
            num_accommodate * ports))
        batch = to_add[0:num_accommodate]
        if not batch:
            continue

        for each_proto in INGRESS_PORTS:
            port = int(each_proto)
            add_params = {
                'ToPort': port,
                'FromPort': port,
                'IpRanges': batch,
                'IpProtocol': 'tcp',
            }
            client.authorize_security_group_ingress(GroupId=each_grp, IpPermissions=[add_params])
            logging.info("Modified " + str(len(batch)) + " rules on security group " + each_grp +
                         " for the port " + each_proto)
        to_add = to_add[num_accommodate:]

def revoke_permissions(client, group, to_revoke):
    """Revoke every CIDR in ``to_revoke`` from ``group`` for each ingress port.

    Args:
        client: boto3 EC2 client.
        group: security group id.
        to_revoke: iterable of CIDR strings to remove.
    """
    ip_ranges = [{'CidrIp': ip_range} for ip_range in to_revoke]
    if not ip_ranges:
        # Nothing to revoke; avoid an empty API request.
        return

    # Revoked rules in each SG for every port number
    for each_proto in INGRESS_PORTS:
        port = int(each_proto)
        revoke_params = {
            'ToPort': port,
            'FromPort': port,
            'IpRanges': ip_ranges,
            'IpProtocol': 'tcp',
        }
        client.revoke_security_group_ingress(GroupId=group, IpPermissions=[revoke_params])
        logging.info("Revoked " + str(len(ip_ranges)) + " rules from the security group " + group +
                     " with port " + each_proto)
    logging.info("Ranges revoked from the security group " + group + " are: " + str(to_revoke))

def create_security_groups(client, response):
    """Create enough tagged groups in VPC_ID to hold NRANGES rules, then re-query.

    Args:
        client: boto3 EC2 client.
        response: ``describe_security_groups`` response for the existing managed groups.

    Returns:
        A fresh ``describe_security_groups`` response for the managed groups.
    """
    num_sgs = len(response['SecurityGroups'])
    logging.info('Found ' + str(num_sgs) + ' security groups')

    # Ceiling division: one extra group for any partial remainder.
    total_sgs_required = NRANGES // NRULES
    if NRANGES % NRULES > 0:
        total_sgs_required += 1
    logging.info('Total number of security groups required to add all the rules: ' + str(total_sgs_required))

    to_create_sgs = 0
    if num_sgs < total_sgs_required:
        to_create_sgs = total_sgs_required - num_sgs
        logging.info('Total number of security groups to be created: ' + str(to_create_sgs))

    # Creates SGs based on the total number of rules that are required to be added
    for _ in range(to_create_sgs):
        random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=3))
        security_group = client.create_security_group(
            Description=NAME + "-" + random_str,
            GroupName=NAME + "-" + random_str,
            # Must match the vpc-id filter used by get_security_groups_for_update.
            VpcId=VPC_ID,
        )
        # Tag right away so a group is never left untagged (and invisible to later
        # runs) if a subsequent creation fails.
        client.create_tags(Resources=[security_group['GroupId']], Tags=[
            {'Key': 'PREFIX_NAME', 'Value': NAME},
        ])

    # NOTE(review): EC2 is eventually consistent; freshly created groups may not
    # appear in this describe call immediately.
    return get_security_groups_for_update(client)

def get_security_groups_for_update(client, create=False):
    """Return the ``describe_security_groups`` response for the managed groups.

    Managed groups are those in VPC_ID tagged ``PREFIX_NAME=<NAME>``.

    Args:
        client: boto3 EC2 client.
        create: when true, create any missing groups first (via
            create_security_groups) and return the refreshed response.
    """
    filters = [
        # 'tag:<key>' requires key and value on the same tag; separate
        # tag-key / tag-value filters could match them on different tags.
        {'Name': "tag:PREFIX_NAME", 'Values': [NAME]},
        {'Name': "vpc-id", 'Values': [VPC_ID]},
    ]

    # Extracting specific security groups with tags
    response = client.describe_security_groups(Filters=filters)

    # Return list of all security groups if none to be created
    if not create:
        return response
    return create_security_groups(client, response)

def get_ip_groups_json(url, expected_hash):
    """Download the IP ranges document and verify its MD5 digest.

    Args:
        url: location of the document (taken from the SNS message).
        expected_hash: hex MD5 digest the body must match.

    Returns:
        The raw document body as bytes (``json.loads`` accepts bytes).

    Raises:
        Exception: the computed digest differs from ``expected_hash``.
    """
    logging.info("Updating from " + url)
    # NOTE(review): urlopen also accepts non-HTTP schemes such as file://;
    # consider restricting the scheme if the SNS source is not fully trusted.
    with urllib.request.urlopen(url) as response:
        ip_json = response.read()

    hash_value = hashlib.md5(ip_json).hexdigest()
    if hash_value != expected_hash:
        raise Exception('MD5 Mismatch: got ' + hash_value + ' expected ' + expected_hash)
    return ip_json

def get_ranges_for_service(ranges, service):
    """Return the ``ip_prefix`` of every entry in ``ranges['prefixes']`` for ``service``.

    Args:
        ranges: parsed IP ranges document with a ``prefixes`` list.
        service: service name to match against each entry's ``service`` field.

    Returns:
        List of CIDR strings in document order.
    """
    service_ranges = [prefix['ip_prefix'] for prefix in ranges['prefixes'] if prefix['service'] == service]
    logging.info('Found ' + service + ' ranges: ' + str(len(service_ranges)))
    return service_ranges

