import socket
import sys
import threading
import time
import netifaces
import paho.mqtt.client as mqtt
import json
import struct

# Network interface used for the raw LNXALL sockets (send and receive).
intf='enp1s0'
# MQTT device SNs whose '/ems/<sn>/+' topics on_connect subscribes to.
gridmetersn='21881E0009F5_gridmeter'
xlmetersn='21881E00048E_xlmeter'

#---LNXALL private ethernet protocol
#-------------------------------------------------------------------------------------------------------
#---------|---dst mac(6 bytes)---|---src mac(6 bytes)---|---eth type(2 bytes)---|
#---------|---proto type(1 byte)---|---len(2 bytes)---|---data---|---crc16(2 bytes)---|---padding---|
#-------------------------------------------------------------------------------------------------------
#   len counts proto type + len + data + crc16; crc16 covers proto type, len and data;
#   frames shorter than 64 bytes are padded with 0xff bytes (see enpacket).
#--proto type: 0----|---discover data len(2 bytes)---|---dev segment data---|
#      dev segment data:---|---dev segment id(1 byte)---|---segment len(1 byte)---|---segment data---|
#          dev segment id: 0-SN 1-capacity 2-power 3-fire 4-gmsn
#--proto type: 1----|---645 data len(2 bytes)---|---645 meter data---|
#--proto type: 2----|---485 data len(2 bytes)---|---485 meter data---|
#--proto type: 3----|---json meter data (the raw MQTT payload)---|

# Byte-order names for int.to_bytes / int.from_bytes.
BIG_ENDIAN = 'big'
LITTLE_ENDIAN = 'little'

# Seconds between discover broadcasts (devdiscover re-arms itself with this).
LNXALL_DISCOVER_INVL=30

#eth type
LNXALL_ETH_TYPE = 0x8120 #ethernet type for lnxall private ethernet protocol

#proto type
LNXALL_DEV = 0x00 #device discover and manage
LNXALL_645METER = 0x01 #meter data in 645 format
LNXALL_485METER = 0x02 #meter data in 485 format
LNXALL_JSONMETER = 0x03 #meter data in json format

#dev segment id
LNXALL_DEV_SEGID_SN = 0x00
LNXALL_DEV_SEGID_CAP = 0x01
LNXALL_DEV_SEGID_PWR = 0x02
LNXALL_DEV_SEGID_FIRE = 0x03
LNXALL_DEV_SEGID_GMSN = 0x04

#meter data type
#LNXALL_METER_RAWDATA = 0x00 #indicate meter data as raw protocol data
#LNXALL_METER_JSON = 0x01 #indicate meter data in json format

# Values advertised in this node's own discover frames (lnxall_dev).
capacity = 215 #kwh
power = 100 #kw
fire = 0

# This node's MAC address; a broadcast placeholder until main() reads the real MAC of intf.
macaddr=b"\xff\xff\xff\xff\xff\xff"
BROADCAST=b"\xff\xff\xff\xff\xff\xff"

# Factory config file; main() reads combosn from its 'sn = ' entry.
facfilename='/app/config/fac.ini'
# MQTT broker connection settings.
# NOTE(review): the broker credentials are hard-coded in source; consider moving them to configuration.
MQTT_SERVER_IP='172.18.39.101'
MQTT_SERVER_PORT=3883
MQTT_SERVER_USERNAME='localuser'
MQTT_SERVER_PASSWORD='dywl@galaxy'
# NOTE(review): metersn, bmsmetersn and meterdata are not referenced by the active code in this file.
# pcssn is set to gmsn+'_PCS1' by meter_thread.
metersn=''
bmsmetersn=''
pcssn=''
meterdata={}
# NOTE(review): no active code in this file assigns gmsn, so it stays '' — TODO confirm where it is meant to be set.
gmsn='' #gridmeter gateway sn


# Filled in by main() from facfilename.
combosn='' #combo sn

# interface name -> raw send socket
socks={}
# peer MAC (int) -> {'device': decoded segments, 'last_active': timestamp, optional 'gmsn', optional 'meter'}
lnxalldevs={}
# NOTE(review): gmsn is still '' at this point, so the client id is always '_ETHPROTO'.
clientid=gmsn+'_ETHPROTO'
mqttclient=mqtt.Client(client_id = clientid,clean_session = True)

gmflag=1 #indicate whether gridmeter on this device, default 1; when rcv gridmeter data from link, set 0
is_master = False  # master flag: True when gridmetersn contains gmsn (computed in main)

# Used by check_connection: failed health checks since the last reset, and the
# time the connection was last confirmed healthy.
reconnect_attempts = 0
last_successful_communication = time.time() 
last_gridinfo_publish = 0  # time of the last GridGroup publication (seconds since the epoch)

# Minimum seconds between "connection is active" log/counter resets in check_connection.
CHECK_INTERVAL_NORMAL = 30
# Current check_connection timer (set by start_periodic_check).
periodic_timer = None

class lnxall_dev:
    """Description of a LNXALL device as carried in LNXALL_DEV (discover) frames.

    Wire format: a sequence of segments, each
    ``segment id (1 byte) | segment length (1 byte) | segment data``,
    where the segment ids are the LNXALL_DEV_SEGID_* constants.
    """

    def __init__(self, sn, capacity, power, fire, gmsn=None):
        self.sn = sn  # str; its UTF-8 encoding must fit in 255 bytes
        self.capacity = capacity  # uint [0-65535], 2 bytes on the wire
        self.power = power  # uint [0-65535], 2 bytes on the wire
        self.fire = fire  # uint8 [0-1], 1 byte on the wire
        # gmsn falls back to sn when it is None or empty.
        self.gmsn = gmsn if gmsn else sn

    @staticmethod
    def _text_segment(segid, text):
        """Encode *text* as one UTF-8 segment whose length byte counts encoded bytes."""
        raw = bytes(text, encoding='utf-8')
        return segid.to_bytes(1, BIG_ENDIAN) + len(raw).to_bytes(1, BIG_ENDIAN) + raw

    def encode(self):
        """Serialize the SN, capacity, power, fire and gmsn segments, in that order.

        Raises:
            OverflowError: a value does not fit its wire width (for example,
                an SN longer than 255 UTF-8 bytes).
        """
        # The length byte must count encoded bytes, not characters. The two
        # differ for non-ASCII text, and a character count breaks the framing
        # seen by decode().
        payload = self._text_segment(LNXALL_DEV_SEGID_SN, self.sn)
        payload += LNXALL_DEV_SEGID_CAP.to_bytes(1, BIG_ENDIAN) + b'\2' + self.capacity.to_bytes(2, BIG_ENDIAN)
        payload += LNXALL_DEV_SEGID_PWR.to_bytes(1, BIG_ENDIAN) + b'\2' + self.power.to_bytes(2, BIG_ENDIAN)
        payload += LNXALL_DEV_SEGID_FIRE.to_bytes(1, BIG_ENDIAN) + b'\1' + self.fire.to_bytes(1, BIG_ENDIAN)
        payload += self._text_segment(LNXALL_DEV_SEGID_GMSN, self.gmsn)
        return payload

    @staticmethod
    def decode(devdata):
        """Parse segment data into a ``{segment id: raw bytes}`` dict.

        Returns {} when the SN or fire segment is missing. When a gmsn segment
        is present, its text is also stored under the key 'gmsn_str'. Empty
        input or a truncated trailing segment header is tolerated; it no
        longer raises IndexError.
        """
        dev = {}
        index = 0
        print(devdata)
        # Each segment needs at least its 2-byte (id, length) header.
        while index + 2 <= len(devdata):
            segid = devdata[index]
            seglen = devdata[index + 1]
            dev[segid] = devdata[index + 2:index + 2 + seglen]
            index += 2 + seglen
        for required in (LNXALL_DEV_SEGID_SN, LNXALL_DEV_SEGID_FIRE):
            if required not in dev:
                print("segid", required, "notfound")
                return {}
        if LNXALL_DEV_SEGID_GMSN in dev:
            # errors='replace': a garbled peer SN must not abort frame handling.
            dev['gmsn_str'] = dev[LNXALL_DEV_SEGID_GMSN].decode('utf-8', errors='replace')
        return dev

class lnxall_meter:
    """Meter payload with a 2-byte big-endian length prefix: ``len | data``."""

    def __init__(self, data):
        self.data = data  # bytes

    def encode(self):
        """Return ``self.data`` prefixed with its 2-byte big-endian length."""
        payload = len(self.data).to_bytes(2, BIG_ENDIAN) + self.data
        return payload

    @staticmethod
    def decode(devdata):
        """Return ``{'data': ...}`` holding everything after the first 2 bytes.

        The 2-byte length prefix is skipped but not used to trim the data;
        callers receive the full remainder.
        """
        return {'data': devdata[2:]}

def on_connect(client, userdata, flags, rc):
    """paho-mqtt CONNACK callback: log the result code and subscribe.

    Subscribes to every sub-topic ('/ems/<sn>/+') of the grid meter, the XL
    meter and the PCS. paho calls this after each (re)connection, so the
    subscriptions are renewed every time.
    """
    print("Connected with result code " + str(rc))
    for device_sn in (gridmetersn, xlmetersn, pcssn):
        mqttclient.subscribe('/ems/{}/+'.format(device_sn))

def reconnect():
    """Make one synchronous reconnect attempt with the shared MQTT client.

    Any failure is logged and swallowed. reconnect_attempts is only read for
    the log line; this function does not increment it. It is not called from
    the code visible here.
    """
    global reconnect_attempts
    try:
        attempt_no = reconnect_attempts + 1
        print(f"Reconnect attempt {attempt_no}")
        mqttclient.reconnect()
    except Exception as err:
        print(f"Reconnect failed: {err}")

def test_mqtt_connection(timeout=5):
    """Probe the broker connection by publishing a QoS-0 heartbeat.

    Args:
        timeout: seconds to wait for paho to report the heartbeat as sent.

    Returns:
        True if the heartbeat was sent within *timeout*. False if publishing
        failed (paho raises from wait_for_publish when the publish was
        rejected, e.g. while disconnected) or the wait timed out.
    """
    try:
        result = mqttclient.publish("test/connection", payload="heartbeat", qos=0, retain=False)
        result.wait_for_publish(timeout=timeout)
        # wait_for_publish() returns silently when the timeout expires, so the
        # outcome must be read back explicitly; otherwise a timeout counts as success.
        return result.is_published()
    except Exception as e:
        print(f"Connection test failed: {e}")
        return False

def check_connection():
    """Periodic MQTT health check, run by the start_periodic_check timer.

    The connection is unhealthy when the client reports it is not connected
    or the heartbeat publish does not complete (test_mqtt_connection). Each
    unhealthy check increments reconnect_attempts. Once 10 failures have been
    counted since the last reset, the whole process exits with code 1
    (presumably so an external supervisor restarts it — TODO confirm). A
    healthy check resets the counter at most once every CHECK_INTERVAL_NORMAL
    seconds.
    """
    import os

    global reconnect_attempts, last_successful_communication
    try:
        if not mqttclient.is_connected() or not test_mqtt_connection():
            if reconnect_attempts < 10:
                print("MQTT client is disconnected or stale. exit")
                reconnect_attempts += 1
            else:
                print("MQTT connection lost for too long, terminating process")
                sys.stdout.flush()
                # This runs on a threading.Timer thread. There, sys.exit()
                # only raises SystemExit in that thread and the process keeps
                # running. os._exit() terminates the whole process (skipping
                # cleanup handlers, hence the explicit flush above).
                os._exit(1)
        elif (time.time() - last_successful_communication) > CHECK_INTERVAL_NORMAL:
            print("MQTT connection is active.")
            last_successful_communication = time.time()
            reconnect_attempts = 0
    except Exception as e:
        print(f"Error in connection check: {e}")

def on_disconnect(client, userdata, rc):
    """paho-mqtt disconnect callback: log the disconnect reason code.

    Per the paho-mqtt docs, rc is 0 when the client itself requested the
    disconnect and non-zero for an unexpected disconnection.
    """
    # print() does no %-interpolation; the former logging-style call printed
    # a literal "%s" followed by the code.
    print(f"MQTT client disconnected with code {rc}")

def _publish_gridgroup(current_time):
    """Publish the master's GridGroup service report; return True on success."""
    try:
        # Snapshot: lnxalldevs is mutated concurrently by the receive loop
        # (devresolve) and the check_devices timer thread. Iterating the live
        # dict could raise "dictionary changed size during iteration".
        devices = list(lnxalldevs.values())
        slave_SN_list = [{"slave_SN": dev_info['gmsn']} for dev_info in devices if 'gmsn' in dev_info]

        payload = {
            "mi": 2025,
            "timestamp": int(current_time),
            "identifier": "GridGroup",
            "sn": f"{combosn}_EMSCTRL",
            "tags": {
                "en-peer-to-peer": 1,
                "master_SN": gmsn,
                "slave_SN_list": slave_SN_list
            },
            "gw_sn": combosn,
            "data_type": "service"
        }
        topic = f"/ems/{combosn}_EMSCTRL/service"
        emsmsg = json.dumps(payload)
        mqttclient.publish(topic, emsmsg)
        print(f"Master published GridGroup to {topic}: {emsmsg}")
        return True
    except Exception as e:
        print(f"Failed to publish GridGroup: {e}")
        return False


def _forward_to_ethernet(msg):
    """Broadcast an MQTT JSON payload as a LNXALL_JSONMETER frame on every socket.

    PCS topics are forwarded only when their 'identifier' is
    'Read101_125Data' or when they have no 'identifier'. Errors (invalid
    JSON, send failures) are logged, not raised.
    """
    try:
        jsonobj = json.loads(msg.payload)
        if msg.topic.find('PCS') > 0 and 'identifier' in jsonobj and jsonobj['identifier'] != 'Read101_125Data':
            return

        packet = enpacket(BROADCAST, macaddr, LNXALL_JSONMETER, msg.payload)
        print(packet)
        for ifname in socks:
            socks[ifname].send(packet)
    except Exception as e:
        print(f"Error processing message: {e}")


def on_message(client, userdata, msg):
    """paho-mqtt message callback.

    1. On the master node, publishes a GridGroup report on the first message
       that arrives at least 30 s after the previous successful report.
    2. Forwards the payload to the raw Ethernet sockets (_forward_to_ethernet).

    Both steps catch and log their own exceptions, because an exception
    escaping a paho callback can abort the network loop.
    """
    global last_gridinfo_publish
    print(msg.topic+" "+str(msg.payload))

    current_time = time.time()
    if is_master and current_time - last_gridinfo_publish >= 30:
        if _publish_gridgroup(current_time):
            last_gridinfo_publish = current_time

    _forward_to_ethernet(msg)

# def on_message_findmeter(client, userdata, msg):
#     #print(msg.payload)
#     global gmsn,bmsmetersn,pcssn,gmflag
#     #if msg.topic.find("gridmeter")>0:
#         #global metersn
#         #for s in msg.topic.split('/'):
#             #if s.find('gridmeter')>0:
#                 #metersn=s
#                 #break
#         #print('find metersn ', metersn, gmsn+'_gridmeter')
#         #if gmsn == '':
#             #return
#         #if metersn==gmsn+'_gridmeter' and gmflag==1 :
#             #client.unsubscribe('/ems/#')
#             #client.on_message=on_message
#             #client.subscribe('/ems/{}/+'.format(metersn))
#             #if bmsmetersn == '':
#                 #bmsmetersn=gmsn+'_bmsmeter'
#             #client.subscribe('/ems/{}/+'.format(bmsmetersn))
#         #elif gmflag == 0:
#             #client.unsubscribe('/ems/#')
#             #client.on_message=on_message
#             #if bmsmetersn == '':
#                 #bmsmetersn=gmsn+'_bmsmeter'
#             #client.subscribe('/ems/{}/+'.format(bmsmetersn))
#     #elif msg.topic.find('PCS')>0:
#     if msg.topic.find('PCS')>0:
#         jsonobj=json.loads(msg.payload)
#         client.unsubscribe('/ems/#')
#         client.on_message=on_message
#         client.subscribe('/ems/{}/+'.format(gridmetersn))
#         client.subscribe('/ems/{}/+'.format(xlmetersn))
#         if pcssn == '':
#             pcssn=gmsn+'_PCS1'
#         client.subscribe('/ems/{}/+'.format(pcssn))
#         #if bmsmetersn == '':
#             #bmsmetersn=gmsn+'_bmsmeter'
#         #client.subscribe('/ems/{}/+'.format(bmsmetersn))


class meter_thread(threading.Thread):
    """Thread that runs the paho-mqtt network loop for the shared mqttclient.

    Construction blocks: __init__ sets pcssn, installs the MQTT callbacks and
    credentials, and retries connect() every 10 s until it returns 0. It then
    waits another 5 s. run() calls loop_forever().
    """

    def __init__(self):
        global pcssn
        super().__init__()
        pcssn = gmsn + '_PCS1'
        mqttclient.on_connect = on_connect
        mqttclient.on_disconnect = on_disconnect
        mqttclient.on_message = on_message
        mqttclient.username_pw_set(username=MQTT_SERVER_USERNAME, password=MQTT_SERVER_PASSWORD)
        retry_delay = 10  # seconds between connect attempts
        while True:
            try:
                rc = mqttclient.connect(MQTT_SERVER_IP, MQTT_SERVER_PORT, keepalive=60)
                if rc == 0:
                    break
                # A non-zero return code previously looped again immediately
                # (busy spin); it now waits like the exception paths.
                print(f"MQTT connect returned {rc}, retrying")
            except socket.timeout:
                print("socket connect time out")
            except ConnectionRefusedError:
                print("port refused")
            except OSError as e:
                print(f"发生了 OSError 错误：{e}")
            time.sleep(retry_delay)
        time.sleep(5)

    def run(self):
        """Run paho's blocking network loop (returns after disconnect())."""
        mqttclient.loop_forever()

def crc16(data):
    """Return the 16-bit CRC of *data* (bytes-like).

    Polynomial 0x1021, initial value 0, MSB-first, no reflection and no
    final XOR (the CRC-16/XMODEM parameter set). Empty input yields 0.

    >>> hex(crc16(b"123456789"))
    '0x31c3'
    """
    crc = 0
    for byte in data:
        # Feed the next byte into the top of the register, then shift out
        # 8 bits, applying the polynomial whenever a 1 falls off bit 15.
        crc ^= byte << 8
        for _ in range(8):
            if crc & 0x8000:
                crc = ((crc << 1) ^ 0x1021) & 0xFFFF
            else:
                crc = (crc << 1) & 0xFFFF
    return crc

def macaddress(mac:int):
    """Format a 48-bit integer as a lowercase, colon-separated MAC string."""
    octets = mac.to_bytes(6, BIG_ENDIAN)
    return ':'.join(f'{octet:02x}' for octet in octets)

def devdiscover(macaddr,mydev):
    """Broadcast this node's device description and repeat every LNXALL_DISCOVER_INVL s.

    Sends a LNXALL_DEV frame built from *mydev* on every socket in socks.
    When gmsn is set, it also publishes the device count (known peers plus
    this node) to /ems/<gmsn>_gridmeter/service/set.
    """
    # Re-arm first. Previously a send error was raised before the Timer was
    # created, which ended periodic discovery for good.
    timer = threading.Timer(LNXALL_DISCOVER_INVL, devdiscover, (macaddr, mydev))
    # Daemon, like the other periodic timers, so it cannot keep the process alive.
    timer.daemon = True
    timer.start()

    packet = enpacket(BROADCAST, macaddr, LNXALL_DEV, mydev.encode())
    print(macaddr, packet)
    for ifname, sock in list(socks.items()):
        try:
            sock.send(packet)
        except OSError as e:
            print(f"discover send on {ifname} failed: {e}")

    if gmsn != '':
        emstopic = '/ems/{}_gridmeter/service/set'.format(gmsn)
        cobj = {
            'sn': gmsn + '_gridmeter',
            'gw_sn': gmsn,
            'devcount': len(lnxalldevs) + 1,
            'identifier': 'devcount',
            'data_type': 'service',
            'mi': 0,
            'time': int(time.time()),
        }
        emsmsg = json.dumps(cobj)
        mqttclient.publish(emstopic, emsmsg)
        print('publish: ', emstopic, emsmsg)

def devresolve(srcmac:int, packet, intf):
    """Handle a LNXALL_DEV (discover) payload from *srcmac*, received on *intf*.

    A peer not yet in lnxalldevs first gets a unicast reply carrying this
    node's own description (combosn, capacity, power, fire, gmsn). The
    peer's payload is then decoded and stored in lnxalldevs with a fresh
    'last_active' timestamp, plus the peer's gmsn when it sent one. Any
    previous entry for the peer is replaced.
    """
    print("received discover from ", macaddress(srcmac))
    is_new_peer = srcmac not in lnxalldevs
    if is_new_peer:
        reply_dst = srcmac.to_bytes(6, BIG_ENDIAN)
        reply_body = lnxall_dev(combosn, capacity, power, fire, gmsn).encode()
        socks[intf].send(enpacket(reply_dst, macaddr, LNXALL_DEV, reply_body))

    decoded = lnxall_dev.decode(packet)
    entry = {'device': decoded, 'last_active': time.time()}

    # Remember the peer's gmsn (used for the GridGroup slave list).
    if decoded and 'gmsn_str' in decoded:
        peer_gmsn = decoded['gmsn_str']
        entry['gmsn'] = peer_gmsn
        print(f"Recorded gmsn {peer_gmsn} for device {macaddress(srcmac)}")

    lnxalldevs[srcmac] = entry

def _republish_meter(packet):
    """Re-publish a forwarded meter JSON payload under this node's meter SN.

    The SN is <gmsn>_xlmeter when the payload contains b'xlmeter', otherwise
    <gmsn>_gridmeter. The payload is published once to /ems/<sn>/service with
    sn and gw_sn rewritten. When it has 'tags', it is published again in
    "mmeter" form to /ems/<sn>/service/set.
    """
    if packet.find(b'xlmeter') > 0:
        fmetersn = gmsn + '_xlmeter'
    else:
        fmetersn = gmsn + '_gridmeter'
    jsonobj = json.loads(packet)
    jsonobj['sn'] = fmetersn
    jsonobj['gw_sn'] = gmsn
    emstopic = '/ems/{}/service'.format(fmetersn)
    emsmsg = json.dumps(jsonobj)
    mqttclient.publish(emstopic, emsmsg)
    print('publish: ', emstopic, emsmsg)

    if 'tags' not in jsonobj:
        # No tags means no mmeter form; skip instead of raising KeyError.
        return
    # "mmeter" form: tags serialized into a tag_node string, routing fields dropped.
    jsonobj['tag_node'] = json.dumps(jsonobj['tags'])
    del jsonobj['tags']
    del jsonobj['gw_sn']
    jsonobj.pop('data_type', None)
    jsonobj['appKey'] = 1
    mmetertopic = '/ems/{}/service/set'.format(fmetersn)
    mmetermsg = json.dumps(jsonobj)
    mqttclient.publish(mmetertopic, mmetermsg)
    print('publish: mmeter ', mmetertopic, mmetermsg)


def _republish_pcs(packet):
    """Re-publish a forwarded PCS JSON payload (with 'tags') to /ems/<gmsn>_PCS1/service/set."""
    jsonobj = json.loads(packet)
    if 'tags' not in jsonobj:
        return
    jsonobj['tag_node'] = json.dumps(jsonobj['tags'])
    del jsonobj['tags']
    mmetertopic = '/ems/{}_PCS1/service/set'.format(gmsn)
    mmetermsg = json.dumps(jsonobj)
    mqttclient.publish(mmetertopic, mmetermsg)
    print('publish', mmetertopic, mmetermsg)


def meterresolve(srcmac:int, packet):
    """Handle a meter-data payload (645/485/JSON frame) from a known peer.

    The payload is stored under lnxalldevs[srcmac]['meter']. When gmsn is set,
    payloads containing b'meter' are re-published via _republish_meter. Other
    payloads containing b'PCS1' are re-published via _republish_pcs. Invalid
    JSON payloads are logged and dropped instead of aborting the receive loop.
    """
    global gmflag
    # .get(): check_devices may remove the entry concurrently from its timer thread.
    dev = lnxalldevs.get(srcmac)
    if dev is None:
        print("dev ",macaddress(srcmac)," not in lnxalldevs")
        return
    dev['meter'] = lnxall_meter.decode(packet)
    print(packet)
    if len(gmsn) == 0:
        return
    try:
        if packet.find(b'meter') > 0:
            # Grid-meter data is arriving over the link, so clear gmflag.
            # `global` is required: a bare assignment only binds a dead local.
            gmflag = 0
            _republish_meter(packet)
        elif packet.find(b'PCS1') > 0:
            _republish_pcs(packet)
    except (ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and invalid UTF-8;
        # TypeError covers JSON that is not an object.
        print(f"Invalid meter payload from {macaddress(srcmac)}: {e}")

def resolve(packet, intf):
    """Dispatch one received Ethernet frame by its LNXALL message type.

    Frame layout (as built by enpacket): dst mac (6) | src mac (6) |
    eth type (2) | msg type (1) | len (2) | data | crc16 (2) | padding.
    len counts msg type + len + data + crc, so data is packet[17:12 + len].
    The CRC is not verified here.

    Args:
        packet: raw frame bytes from the AF_PACKET socket.
        intf: name of the interface the frame arrived on.
    """
    dstmac = int.from_bytes(packet[:6], BIG_ENDIAN)
    srcmac = int.from_bytes(packet[6:12], BIG_ENDIAN)
    typeid = int.from_bytes(packet[12:14], BIG_ENDIAN)
    print('dstmac', macaddress(dstmac), ' srcmac', macaddress(srcmac), typeid)
    # Snapshot: check_devices may delete entries from another thread.
    for mac in list(lnxalldevs):
        print('lnxalldevs:', macaddress(mac))

    if typeid != LNXALL_ETH_TYPE:
        print("unexpected typeid", typeid)
        return
    if len(packet) < 17:
        # Too short to hold msg type and len; previously raised IndexError.
        print("truncated lnxall frame, length", len(packet))
        return

    msgtype = packet[14]
    datalen = int.from_bytes(packet[15:17], BIG_ENDIAN)
    data = packet[17:12 + datalen]
    if msgtype == LNXALL_DEV:
        devresolve(srcmac, data, intf)
        print("after dev resolve", lnxalldevs)
    elif msgtype in (LNXALL_645METER, LNXALL_485METER, LNXALL_JSONMETER):
        meterresolve(srcmac, data)
        if msgtype == LNXALL_645METER:
            print("after meter resolve", lnxalldevs)

def enpacket(dstmac, srcmac, msgtype, data):
    """Build a complete LNXALL Ethernet frame.

    Layout: dst mac (6) | src mac (6) | LNXALL_ETH_TYPE (2) | msgtype (1) |
    len (2) | data | crc16 (2) | padding.
    len counts msgtype + len + data + crc. The CRC covers msgtype, len and
    data. Frames shorter than 64 bytes are padded with 0xff bytes up to 64.

    Args:
        dstmac, srcmac: 6-byte MAC addresses.
        msgtype: one of the LNXALL_* proto type constants.
        data: payload bytes.
    """
    frame_len = 1 + 2 + len(data) + 2
    body = struct.pack(f"!BH{len(data)}s", msgtype, frame_len, data)
    trailer = crc16(body).to_bytes(2, BIG_ENDIAN)
    header = struct.pack("!6s6sH", dstmac, srcmac, LNXALL_ETH_TYPE)
    frame = header + body + trailer
    shortfall = 64 - (14 + frame_len)
    if shortfall > 0:
        frame += b'\xff' * shortfall
    return frame

def start_periodic_check(interval=10):
    """Run check_connection() every *interval* seconds on daemon timers.

    Each tick logs any Exception from check_connection and always schedules
    the next tick. The current timer is stored in the module-level
    periodic_timer. Returns the first timer.
    """
    def _arm(callback):
        # Create, mark daemon, start, and publish the timer as periodic_timer.
        global periodic_timer
        periodic_timer = threading.Timer(interval, callback)
        periodic_timer.daemon = True
        periodic_timer.start()
        return periodic_timer

    def tick():
        try:
            check_connection()
        except Exception as err:
            print(f"Error in periodic check: {err}")
        finally:
            _arm(tick)

    return _arm(tick)

def check_devices():
    """Drop peers not heard from for more than 90 s; repeat every 60 s.

    Runs on daemon timer threads, re-arming itself even when a sweep fails.
    """
    try:
        now = time.time()
        timeout = 90
        # Snapshot: the receive loop inserts entries concurrently, and
        # iterating the live dict can raise RuntimeError, which previously
        # ended the sweep chain permanently.
        stale = [mac for mac, info in list(lnxalldevs.items()) if now - info['last_active'] > timeout]
        for srcmac in stale:
            if lnxalldevs.pop(srcmac, None) is not None:
                print(f"Device {macaddress(srcmac)} removed due to timeout")
    except Exception as e:
        print(f"Error in device check: {e}")
    finally:
        timer = threading.Timer(60, check_devices)
        timer.daemon = True
        timer.start()

def _read_combosn(path):
    """Return the 12 characters following 'sn = ' in the factory file *path*.

    Returns None when the marker is missing. Raises OSError (e.g.
    FileNotFoundError) when the file cannot be read.
    """
    with open(path, 'r') as facfile:
        facinfo = facfile.read()
    marker = 'sn = '
    pos = facinfo.find(marker)
    if pos < 0:
        return None
    start = pos + len(marker)
    return facinfo[start:start + 12]


def main():
    """Service entry point.

    1. Read combosn from facfilename and derive is_master.
    2. Start the MQTT thread (meter_thread blocks until the broker connects).
    3. Read this node's MAC on `intf`, open the send socket, and start
       discovery, the MQTT health check and the stale-peer sweep.
    4. Receive LNXALL frames forever, rebuilding the receive socket after errors.
    """
    import traceback

    global combosn, is_master, gridmetersn, gmsn, macaddr

    # Only the single configured interface `intf` is used. (An earlier variant
    # opened one raw socket per netifaces interface that had an IPv4 address.)

    sn = _read_combosn(facfilename)
    if sn is None:
        # Previously a missing marker silently produced facinfo[4:16].
        print(f"'sn = ' not found in {facfilename}, combosn left as {combosn!r}")
    else:
        combosn = sn
    print(combosn)

    # Master when gridmetersn contains gmsn.
    # NOTE(review): while gmsn is '', `'' in gridmetersn` is always True, so
    # every node considers itself master — confirm this is intended.
    is_master = gmsn in gridmetersn
    print(f"is_master: {is_master}, gridmetersn: {gridmetersn}, gmsn: {gmsn}")

    mydev = lnxall_dev(combosn, capacity, power, fire, gmsn)

    meterthd = meter_thread()
    meterthd.start()

    addrs = netifaces.ifaddresses(intf)
    macstr = addrs[netifaces.AF_PACKET][0]['addr']
    macaddr = int(macstr.replace(':', ''), 16).to_bytes(6, BIG_ENDIAN)
    print("macstr ", intf, ": ", macaddr)

    sendsock = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(LNXALL_ETH_TYPE))
    sendsock.bind((intf, 0))
    socks[intf] = sendsock
    devdiscover(macaddr, mydev)
    start_periodic_check(10)
    check_devices()

    while True:
        # Pre-bind so `finally` never touches an unbound name when socket() fails.
        rawsock = None
        try:
            rawsock = socket.socket(socket.PF_PACKET, socket.SOCK_RAW, socket.htons(LNXALL_ETH_TYPE))
            rawsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            rawsock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
            rawsock.bind((intf, 0))

            print("Socket initialized, starting to receive packets...")

            while True:
                try:
                    packet, packet_info = rawsock.recvfrom(2048)
                    resolve(packet, packet_info[0])
                except KeyboardInterrupt:
                    print("User interrupted, exiting...")
                    # loop_forever() only returns after disconnect(); joining
                    # without it would block forever.
                    mqttclient.disconnect()
                    meterthd.join()
                    return
                except socket.error as e:
                    print(f"Socket error: {e}, attempting to reconnect...")
                    break  # leave the inner loop to rebuild the socket
                except Exception as e:
                    # print() has no exc_info parameter (that call raised
                    # TypeError); print the traceback explicitly.
                    print(f"Unexpected error: {e}, restarting socket...")
                    traceback.print_exc()
                    break  # leave the inner loop to rebuild the socket

        except Exception as outer_error:
            print(f"Error initializing socket: {outer_error}")
        finally:
            if rawsock is not None:
                try:
                    rawsock.close()
                except OSError:
                    pass

        print("Reinitializing socket in 1 second...")
        time.sleep(1)


# Start the service only when run as a script, not when imported.
if __name__ == '__main__':
    main()

