import socket
import threading
import time
import struct
import json
import netifaces
import paho.mqtt.client as mqtt

# Network interface used for sending LNXALL frames and whose MAC identifies this device.
intf='enp2s0'
# MQTT device SNs whose /ems/<sn>/+ topics are forwarded onto the Ethernet link.
gridmetersn='21881E00048E_gridmeter'
xlmetersn='21881E00048E_xlmeter'

#---LNXALL private ethernet protocol
#-------------------------------------------------------------------------------------------------------
#---------|---dst mac(6 bytes)---|---src mac(6 bytes)---|---eth type(2 bytes)---|
#---------|---proto type(1 byte)---|---len(2 bytes)---|---data---|---crc16(2 bytes)---|---padding---|
#-------------------------------------------------------------------------------------------------------
# All multi-byte fields are big-endian. "len" counts proto type + len + data + crc16,
# i.e. len(data) + 5. crc16 is computed over proto type + len + data. Frames shorter
# than 64 bytes are padded with 0xff bytes.
#--proto type: 0----|---discover data len(2 byte)---|---dev segment data---|
#      dev segment data:---|---dev segment id(1 byte)---|---segment len(1 byte)---|---segment data---|
#          dev segment id: 0-SN 1-capacity 2-power 3-fire ......
#--proto type: 1----|---645 data len(2 byte)---|---645 meter data---|
#--proto type: 2----|---485 data len(2 byte)---|---485 meter data---|
#--proto type: 3----|---json meter data---|  (on_message sends the raw MQTT payload as data)

BIG_ENDIAN = 'big'
LITTLE_ENDIAN = 'little'

# Seconds between periodic discover broadcasts.
LNXALL_DISCOVER_INVL=30

#eth type
LNXALL_ETH_TYPE = 0x8120 #ethernet type for lnxall private ethernet protocol

#proto type
LNXALL_DEV = 0x00 #device discover and manage
LNXALL_645METER = 0x01 #meter data in 645 format
LNXALL_485METER = 0x02 #meter data in 485 format
LNXALL_JSONMETER = 0x03 #meter data in json format

#dev segment id
LNXALL_DEV_SEGID_SN = 0x00
LNXALL_DEV_SEGID_CAP = 0x01
LNXALL_DEV_SEGID_PWR = 0x02
LNXALL_DEV_SEGID_FIRE = 0x03

#meter data type (currently unused)
#LNXALL_METER_RAWDATA = 0x00 #indicate meter data as raw protocol data
#LNXALL_METER_JSON = 0x01 #indicate meter data in json format

# Values advertised for this device in its discover frames.
capacity = 215 #kwh
power = 100 #kw
fire = 0

# This device's MAC as 6 raw bytes; placeholder until main() reads it from `intf`.
macaddr=b"\xff\xff\xff\xff\xff\xff"
BROADCAST=b"\xff\xff\xff\xff\xff\xff"

# Factory config file; main() reads the 12-character combo SN after "sn = ".
facfilename='/app/config/fac.ini'
MQTT_SERVER_IP='172.18.39.101'
MQTT_SERVER_PORT=3883
# NOTE(review): broker credentials are hard-coded in source; consider loading them from config/env.
MQTT_SERVER_USERNAME='localuser'
MQTT_SERVER_PASSWORD='dywl@galaxy'
# metersn, bmsmetersn and meterdata are not referenced by any active code in this file.
metersn=''
bmsmetersn=''
pcssn='' #PCS device SN; set to gmsn+'_PCS1' once the gateway SN is learned
meterdata={}
gmsn='' #gridmeter gateway sn (learned from the first PCS MQTT message)
combosn='' #combo sn (read from facfilename in main)

socks={} #interface name -> bound raw send socket
lnxalldevs={} #peer MAC (int) -> decoded device segments (plus 'meter' once meter data arrives)

mqttclient=mqtt.Client('LNXALL_ETHPRO')

gmflag=1 #indicate whether gridmeter on this device, default 1; intended to be set 0 when gridmeter data is received from the link

class lnxall_dev:
    """Device description carried in LNXALL_DEV (discover) frames.

    On the wire the device is a sequence of TLV segments:
    ``segment id (1 byte) | segment len (1 byte) | segment data``.
    """

    def __init__(self, sn, capacity, power, fire):
        self.sn = sn  # str; its UTF-8 encoding must fit in 255 bytes
        self.capacity = capacity  # uint16 [0-65535]
        self.power = power  # uint16 [0-65535]
        self.fire = fire  # uint8 [0-1]

    def encode(self):
        """Serialize the device as SN, capacity, power and fire segments.

        Returns:
            bytes: the concatenated TLV segments.

        Raises:
            OverflowError: if a field does not fit its segment width.
        """
        sn_bytes = bytes(self.sn, encoding='utf-8')
        # The length byte must count encoded bytes, not characters; otherwise a
        # non-ASCII SN would desynchronize every following segment.
        payload = LNXALL_DEV_SEGID_SN.to_bytes(1, BIG_ENDIAN) + len(sn_bytes).to_bytes(1, BIG_ENDIAN) + sn_bytes
        payload += LNXALL_DEV_SEGID_CAP.to_bytes(1, BIG_ENDIAN) + b'\x02' + self.capacity.to_bytes(2, BIG_ENDIAN)
        payload += LNXALL_DEV_SEGID_PWR.to_bytes(1, BIG_ENDIAN) + b'\x02' + self.power.to_bytes(2, BIG_ENDIAN)
        payload += LNXALL_DEV_SEGID_FIRE.to_bytes(1, BIG_ENDIAN) + b'\x01' + self.fire.to_bytes(1, BIG_ENDIAN)
        return payload

    @staticmethod
    def decode(devdata):
        """Parse TLV segments from a discover payload.

        Args:
            devdata: bytes of the discover frame's data section.

        Returns:
            dict mapping segment id (int) to raw segment bytes, or ``{}`` when
            the payload is malformed (truncated header or segment data) or lacks
            the required SN or FIRE segment.
        """
        dev = {}
        index = 0
        total = len(devdata)
        print(devdata)
        while index < total:
            if index + 2 > total:
                # A lone trailing byte cannot hold both segment id and length.
                print("truncated segment header at", index)
                return {}
            segid = devdata[index]
            seglen = devdata[index + 1]
            end = index + 2 + seglen
            if end > total:
                print("segid", segid, "truncated")
                return {}
            dev[segid] = devdata[index + 2:end]
            index = end
        for i in (LNXALL_DEV_SEGID_SN, LNXALL_DEV_SEGID_FIRE):
            if i not in dev:
                print("segid", i, "notfound")
                return {}
        return dev

class lnxall_meter:
    """Meter payload framed with a 2-byte big-endian length prefix."""

    def __init__(self, data):
        self.data = data  # bytes

    def encode(self):
        """Return ``len(data)`` as 2 big-endian bytes followed by ``data``."""
        return len(self.data).to_bytes(2, BIG_ENDIAN) + self.data

    @staticmethod
    def decode(devdata):
        """Inverse of :meth:`encode`.

        Returns:
            dict with key ``'data'`` holding at most the number of bytes
            announced by the 2-byte length prefix (fewer if *devdata* is short).
        """
        datalen = int.from_bytes(devdata[0:2], BIG_ENDIAN)
        # Honour the declared length so trailing bytes are not taken as data.
        # NOTE(review): meterresolve currently passes raw JSON with no length
        # prefix, so the first two JSON bytes are read as the length here.
        return {'data': devdata[2:2 + datalen]}

def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: (re)establish the topic subscriptions.

    While the PCS SN is still unknown, subscribe to the whole ``/ems/#`` tree
    so discovery can find it. Otherwise subscribe only to the grid-meter,
    xl-meter and PCS device topics.
    """
    print("Connected with result code " + str(rc))
    if pcssn == '':
        wanted = ['/ems/#']
    else:
        wanted = ['/ems/{}/+'.format(dev_sn) for dev_sn in (gridmetersn, xlmetersn, pcssn)]
    for topic in wanted:
        mqttclient.subscribe(topic)

def reconnect():
    """Block until ``mqttclient.reconnect()`` succeeds, retrying every 10 seconds.

    A failure is either an OSError or a non-zero return code. Both are logged
    and retried after a pause. Previously a non-zero return code retried
    immediately in a tight loop.
    """
    while True:
        try:
            rc = mqttclient.reconnect()
            if rc == 0:
                break
            print("MQTT reconnect failed, rc", rc)
        except OSError as e:
            # Log message reads: "an OSError occurred: <e>".
            print(f"发生了 OSError 错误：{e}")
        time.sleep(10)

def check_connection():
    """Log MQTT connectivity, reconnect if needed, then re-arm itself for 60 s later.

    Each run schedules exactly one follow-up run on a new threading.Timer,
    so this forms an endless periodic check once started.
    """
    connected = mqttclient.is_connected()
    if connected:
        print("MQTT connection is active.")
    else:
        print("MQTT connection lost. Reconnecting...")
        reconnect()
    next_check = threading.Timer(60, check_connection)
    next_check.start()

def on_disconnect(client, userdata, rc):
    """paho-mqtt disconnect callback: log the result code of unexpected disconnects.

    A result code of 0 is treated as a normal disconnect and produces no output.
    """
    if rc == 0:
        return
    print(f"Unexpected MQTT disconnection. Result code: {rc}")

def on_message(client, userdata, msg):
    """Forward a subscribed MQTT message onto the Ethernet link.

    The raw payload is broadcast as a LNXALL_JSONMETER frame on every send
    socket. On PCS topics, a JSON object carrying an ``identifier`` other than
    ``Read101_125Data`` is dropped. Payloads that are not valid JSON are
    dropped with a log line instead of raising inside the paho callback.
    """
    print(msg.topic + " " + str(msg.payload))
    try:
        jsonobj = json.loads(msg.payload)
    except ValueError as e:
        print("drop non-json payload on", msg.topic, ":", e)
        return
    if (msg.topic.find('PCS') > 0 and isinstance(jsonobj, dict)
            and 'identifier' in jsonobj and jsonobj['identifier'] != 'Read101_125Data'):
        return

    packet = enpacket(BROADCAST, macaddr, LNXALL_JSONMETER, msg.payload)
    print(packet)
    # Snapshot: main() may still be adding sockets while this MQTT thread runs.
    for ifname, sock in list(socks.items()):
        try:
            sock.send(packet)
        except OSError as e:
            # e.g. EMSGSIZE when the payload does not fit in one Ethernet frame.
            print("send failed on", ifname, ":", e)

def on_message_findmeter(client, userdata, msg):
    """Discovery-phase MQTT handler: learn the gateway SN from a PCS message.

    Non-PCS topics are ignored. The first PCS message whose JSON has a
    non-empty string ``gw_sn`` sets the global ``gmsn``. The client then drops
    the ``/ems/#`` wildcard, switches ``on_message`` to the forwarding handler
    and subscribes to the grid-meter, xl-meter and PCS device topics. If
    ``pcssn`` is still empty it becomes ``gmsn + '_PCS1'``.

    Unparseable payloads are logged and ignored, so discovery keeps waiting
    instead of raising inside the paho callback.
    """
    global gmsn, pcssn
    if msg.topic.find('PCS') <= 0:
        return
    try:
        jsonobj = json.loads(msg.payload)
        gwsn = jsonobj['gw_sn']
    except (ValueError, KeyError, TypeError) as e:
        print("cannot read gw_sn from", msg.topic, ":", repr(e))
        return
    if not isinstance(gwsn, str) or gwsn == '':
        # Validate before touching subscriptions so a bad message cannot leave
        # the client half-switched.
        print("invalid gw_sn on", msg.topic, ":", repr(gwsn))
        return

    gmsn = gwsn
    client.unsubscribe('/ems/#')
    client.on_message = on_message
    client.subscribe('/ems/{}/+'.format(gridmetersn))
    client.subscribe('/ems/{}/+'.format(xlmetersn))
    if pcssn == '':
        pcssn = gmsn + '_PCS1'
    client.subscribe('/ems/{}/+'.format(pcssn))

class meter_thread(threading.Thread):
    """Thread that drives the MQTT client's network loop.

    Construction configures callbacks and credentials on the module-level
    ``mqttclient`` and blocks until the initial ``connect()`` succeeds. It then
    arms the periodic connection check and subscribes to ``/ems/#``.
    :meth:`run` calls ``loop_forever()``.
    """

    def __init__(self):
        super().__init__()
        global mqttclient
        mqttclient.on_connect = on_connect
        mqttclient.on_disconnect = on_disconnect
        # Start in discovery mode; on_message_findmeter switches to on_message.
        mqttclient.on_message = on_message_findmeter
        mqttclient.username_pw_set(username=MQTT_SERVER_USERNAME, password=MQTT_SERVER_PASSWORD)
        while True:
            try:
                rc = mqttclient.connect(MQTT_SERVER_IP, MQTT_SERVER_PORT)
                if rc == 0:
                    break
                print("MQTT connect failed, rc", rc)
            except socket.timeout:
                print("socket connect time out")
            except ConnectionRefusedError:
                print("port refused")
            except OSError as e:
                # Log message reads: "an OSError occurred: <e>".
                print(f"发生了 OSError 错误：{e}")
            # Pause on every failure path, including a non-zero return code
            # (previously that case retried in a tight loop).
            time.sleep(10)

        # NOTE(review): with paho-mqtt 1.x, is_connected() only becomes true once
        # the network loop has processed CONNACK, and loop_forever() has not
        # started yet here. An immediate check would report "lost" and force a
        # needless reconnect, so the first check is deferred by one interval.
        # Confirm against the installed paho version.
        threading.Timer(60, check_connection).start()
        mqttclient.subscribe('/ems/#')

    def run(self):
        """Run the paho network loop; blocks for the lifetime of the client."""
        global mqttclient
        mqttclient.loop_forever()

def crc16(data):
    """Return the CRC-16/XMODEM checksum of *data*.

    Parameters: polynomial 0x1021, initial value 0, bits processed MSB first,
    no reflection, no final XOR. An empty input yields 0.

    Args:
        data: bytes-like object (or sequence of ints in 0..255).

    Returns:
        int in the range 0..0xFFFF.
    """
    checksum = 0
    for octet in data:
        # Feed the next 8 message bits into the top of the register, then
        # shift them out one at a time, folding in the polynomial whenever
        # a 1 leaves bit 15.
        checksum ^= octet << 8
        for _ in range(8):
            if checksum & 0x8000:
                checksum = ((checksum << 1) ^ 0x1021) & 0xFFFF
            else:
                checksum = (checksum << 1) & 0xFFFF
    return checksum

def macaddress(mac:int) -> str:
    """Format an integer MAC (6 bytes, big-endian) as lowercase colon-separated hex."""
    octets = mac.to_bytes(6, BIG_ENDIAN)
    pieces = [f'{octet:02x}' for octet in octets]
    return ':'.join(pieces)

def devdiscover(macaddr, mydev):
    """Broadcast this device's discover frame and re-arm the periodic timer.

    The function runs again every LNXALL_DISCOVER_INVL seconds on a
    threading.Timer. Once the gateway SN (``gmsn``) is known, it also publishes
    the device count (known peers + 1) to ``/ems/<gmsn>_gridmeter/service/set``.

    Args:
        macaddr: this device's MAC as 6 raw bytes (source address of the frame).
        mydev: lnxall_dev describing this device.
    """
    packet = enpacket(BROADCAST, macaddr, LNXALL_DEV, mydev.encode())
    print(macaddr, packet)
    for ifname, sock in list(socks.items()):
        try:
            sock.send(packet)
        except OSError as e:
            # An interface error must not stop the discovery timer below from
            # being re-armed; previously the exception ended discovery forever.
            print("discover send failed on", ifname, ":", e)
    threading.Timer(LNXALL_DISCOVER_INVL, devdiscover, (macaddr, mydev)).start()
    if gmsn != '':
        emstopic = '/ems/{}_gridmeter/service/set'.format(gmsn)
        cobj = {
            'sn': gmsn + '_gridmeter',
            'gw_sn': gmsn,
            'devcount': len(lnxalldevs) + 1,  # peers plus this device
            'identifier': 'devcount',
            'data_type': 'service',
            'mi': 0,
            'time': int(time.time()),
        }
        emsmsg = json.dumps(cobj)
        mqttclient.publish(emstopic, emsmsg)
        print('publish: ', emstopic, emsmsg)

def devresolve(srcmac:int, packet, intf):
    """Handle a discover frame from a peer.

    The first time a peer MAC is seen, a unicast discover reply describing this
    device is sent back on the interface the frame arrived on, if a send socket
    exists for it. The peer's decoded segments are then stored (or refreshed)
    in ``lnxalldevs``.

    Args:
        srcmac: peer MAC as an int.
        packet: data section of the discover frame (TLV segments).
        intf: name of the receiving interface.
    """
    print("received discover from ", macaddress(srcmac))
    if srcmac not in lnxalldevs:
        sock = socks.get(intf)
        if sock is None:
            # The receive socket listens on all interfaces, but only `socks`
            # interfaces can transmit; previously this raised KeyError.
            print("no send socket for", intf, "- discover reply skipped")
        else:
            pkt = enpacket(srcmac.to_bytes(6, BIG_ENDIAN), macaddr, LNXALL_DEV,
                           lnxall_dev(combosn, capacity, power, fire).encode())
            try:
                sock.send(pkt)
            except OSError as e:
                print("discover reply failed on", intf, ":", e)
    lnxalldevs[srcmac] = lnxall_dev.decode(packet)

def _load_json_object(packet):
    """Parse *packet* as a JSON object; log and return None if it is not one."""
    try:
        obj = json.loads(packet)
    except ValueError as e:
        print("meter payload is not valid json:", e)
        return None
    if not isinstance(obj, dict):
        print("meter payload is not a json object")
        return None
    return obj


def _publish_meter(packet):
    """Republish a grid/xl meter JSON payload under this gateway's SN.

    Two publishes are made: the payload with ``sn``/``gw_sn`` rewritten, to
    ``/ems/<sn>/service``, and a variant with ``tags`` serialized into
    ``tag_node`` (``gw_sn``/``data_type`` removed, ``appKey`` = 1), to
    ``/ems/<sn>/service/set``. The second publish is skipped if there is no
    ``tags`` key.
    """
    if packet.find(b'xlmeter') > 0:
        fmetersn = gmsn + '_xlmeter'
    else:
        fmetersn = gmsn + '_gridmeter'
    jsonobj = _load_json_object(packet)
    if jsonobj is None:
        return

    jsonobj['sn'] = fmetersn
    jsonobj['gw_sn'] = gmsn
    emstopic = '/ems/{}/service'.format(fmetersn)
    emsmsg = json.dumps(jsonobj)
    mqttclient.publish(emstopic, emsmsg)
    print('publish: ', emstopic, emsmsg)

    if 'tags' not in jsonobj:
        print('no tags in meter payload, skip', fmetersn, 'service/set')
        return
    jsonobj['tag_node'] = json.dumps(jsonobj.pop('tags'))
    del jsonobj['gw_sn']
    jsonobj.pop('data_type', None)
    jsonobj['appKey'] = 1
    mmetertopic = '/ems/{}/service/set'.format(fmetersn)
    mmetermsg = json.dumps(jsonobj)
    mqttclient.publish(mmetertopic, mmetermsg)
    print('publish: mmeter ', mmetertopic, mmetermsg)


def _publish_pcs(packet):
    """Republish a PCS1 JSON payload with ``tags`` serialized into ``tag_node``.

    Nothing is published if the payload has no ``tags`` key.
    """
    jsonobj = _load_json_object(packet)
    if jsonobj is None or 'tags' not in jsonobj:
        return
    jsonobj['tag_node'] = json.dumps(jsonobj.pop('tags'))
    mmetertopic = '/ems/{}_PCS1/service/set'.format(gmsn)
    mmetermsg = json.dumps(jsonobj)
    mqttclient.publish(mmetertopic, mmetermsg)
    print('publish', mmetertopic, mmetermsg)


def meterresolve(srcmac:int, packet):
    """Handle a meter-data frame from a known peer.

    The payload is stored in the peer's ``lnxalldevs`` entry. Once the gateway
    SN is known, the payload is republished to MQTT:
    - payloads containing ``b'meter'`` go through the grid/xl meter path, which
      also clears the global ``gmflag``;
    - otherwise, payloads containing ``b'PCS1'`` go through the PCS path.

    Frames from MACs not in ``lnxalldevs`` are logged and ignored.
    """
    global gmflag
    if srcmac not in lnxalldevs:
        print("dev ", macaddress(srcmac), " not in lnxalldevs")
        return
    dev = lnxalldevs[srcmac]
    # NOTE(review): the JSON path sends the raw payload without a 2-byte length
    # prefix, so lnxall_meter.decode strips its first two bytes here.
    dev['meter'] = lnxall_meter.decode(packet)
    print(packet)
    if not gmsn:
        return
    if packet.find(b'meter') > 0:
        # Grid-meter data arrived over the link. The original assigned a local
        # here, so the module flag was never cleared.
        gmflag = 0
        _publish_meter(packet)
    elif packet.find(b'PCS1') > 0:
        _publish_pcs(packet)

def resolve(packet, intf):
    """Parse one received Ethernet frame and dispatch LNXALL payloads by proto type.

    Args:
        packet: raw frame bytes starting at the destination MAC.
        intf: name of the interface the frame was received on.
    """
    dstmac = int.from_bytes(packet[:6], BIG_ENDIAN)
    srcmac = int.from_bytes(packet[6:12], BIG_ENDIAN)
    typeid = int.from_bytes(packet[12:14], BIG_ENDIAN)
    print('dstmac', macaddress(dstmac), ' srcmac', macaddress(srcmac), typeid)
    if typeid != LNXALL_ETH_TYPE:
        print("unexpected typeid", typeid)
        return
    if len(packet) < 17:
        # Too short to hold proto type + len; previously raised IndexError.
        print("short lnxall frame, len", len(packet))
        return

    msgtype = packet[14]
    datalen = int.from_bytes(packet[15:17], BIG_ENDIAN)
    # datalen counts proto type(1) + len(2) + data + crc16(2), so data spans
    # packet[17 : 14 + datalen - 2] == packet[17 : 12 + datalen].
    # NOTE(review): the trailing crc16 is not verified on receive.
    data = packet[17:12 + datalen]
    if msgtype == LNXALL_DEV:
        devresolve(srcmac, data, intf)
        print("after dev resolve", lnxalldevs)
    elif msgtype == LNXALL_645METER:
        meterresolve(srcmac, data)
        print("after meter resolve", lnxalldevs)
    elif msgtype in (LNXALL_485METER, LNXALL_JSONMETER):
        meterresolve(srcmac, data)

def enpacket(dstmac, srcmac, msgtype, data):
    """Build a complete LNXALL Ethernet frame.

    Layout: dst mac | src mac | LNXALL_ETH_TYPE | proto type | len | data |
    crc16, then 0xff padding up to 64 bytes. ``len`` is ``len(data) + 5``
    (proto type, len, data and crc16). The CRC covers proto type + len + data.

    Args:
        dstmac: destination MAC, 6 raw bytes.
        srcmac: source MAC, 6 raw bytes.
        msgtype: proto type byte (LNXALL_DEV, LNXALL_JSONMETER, ...).
        data: payload bytes.

    Returns:
        bytes: the frame, ready for a raw socket send().
    """
    size = len(data)
    proto_len = size + 5
    body = struct.pack(f"!BH{size}s", msgtype, proto_len, data)
    frame = b"".join((
        struct.pack("!6s6sH", dstmac, srcmac, LNXALL_ETH_TYPE),
        body,
        crc16(body).to_bytes(2, BIG_ENDIAN),
    ))
    shortfall = 64 - (14 + proto_len)
    if shortfall > 0:
        frame += b"\xff" * shortfall
    return frame

def main():
    """Program entry point.

    Steps:
    1. Read the combo SN from the factory config file.
    2. Start the MQTT thread.
    3. Resolve this device's MAC on `intf`.
    4. Open the send socket and start periodic discovery.
    5. Receive and dispatch LNXALL frames until Ctrl-C.
    """
    global combosn, macaddr
    with open(facfilename, 'r') as facfile:
        facinfo = facfile.read()
    marker = 'sn = '
    pos = facinfo.find(marker)
    if pos < 0:
        # Previously find() == -1 made this slice an arbitrary 12 characters.
        print("'sn = ' not found in", facfilename)
    else:
        start = pos + len(marker)
        combosn = facinfo[start:start + 12]
    print(combosn)

    mydev = lnxall_dev(combosn, capacity, power, fire)

    meterthd = meter_thread()
    meterthd.start()

    addrs = netifaces.ifaddresses(intf)
    macstr = addrs[netifaces.AF_PACKET][0]['addr']
    macaddr = int(macstr.replace(':', ''), 16).to_bytes(6, BIG_ENDIAN)
    print("macstr ", intf, ": ", macaddr)

    sendsock = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(LNXALL_ETH_TYPE))
    sendsock.bind((intf, 0))
    socks[intf] = sendsock
    devdiscover(macaddr, mydev)

    # Unbound: receives LNXALL frames from every interface.
    rawsock = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(LNXALL_ETH_TYPE))
    try:
        while True:
            # A full-size frame is 1514 bytes (14-byte header + 1500 MTU), so a
            # 1500-byte buffer truncated its tail (data end + crc).
            packet, packet_info = rawsock.recvfrom(65535)
            try:
                resolve(packet, packet_info[0])
            except Exception as e:
                # Top-level boundary: one bad frame must not stop the receiver.
                print("failed to handle frame from", packet_info[0], ":", repr(e))
    except KeyboardInterrupt:
        pass
    finally:
        rawsock.close()
    meterthd.join()

# Run the bridge only when executed as a script, not when imported.
if "__main__" == __name__:
    main()
