
# swot_gamma_pipeline.py
import os, math, sys, shutil
import numpy as np
sys.path.insert(0, '/usr/local/GAMMA_20241205')
import py_gamma as pg

def pget(par_path, key, idx=0, cast=float):
    """Read one value from a GAMMA parameter file and convert it.

    Parameters
    ----------
    par_path : str
        Path to the parameter file; parsed with ``pg.ParFile`` on every call.
    key : str
        Parameter key to look up, e.g. ``"near_range_slc"``.
    idx : int, default 0
        Element to take when ``get_value`` returns a list/tuple (multi-element
        entry). Ignored when it returns a scalar.
    cast : callable, default float
        Conversion applied to the selected element.

    Returns
    -------
    The selected element converted with *cast*.

    Raises
    ------
    KeyError
        If the lookup yields no value (``None`` or an empty list/tuple).
    """
    val = pg.ParFile(par_path).get_value(key)

    # NOTE(review): assumes py_gamma signals a missing key by returning None
    # (if it raises instead, this guard is simply never reached). Without it,
    # cast(None) fails with a message that names neither the key nor the file.
    if val is None:
        raise KeyError(f"{key!r} not found in {par_path}")

    if isinstance(val, (list, tuple)):
        if not val:
            raise KeyError(f"{key!r} has no value in {par_path}")
        val = val[idx]
    return cast(val)

def ensure_dir(path):
    """Make sure directory *path* exists and hand back its absolute form.

    Missing parent directories are created too; an existing directory is
    left alone.
    """
    # Create from the path exactly as given; normalise only afterwards.
    os.makedirs(path, exist_ok=True)
    absolute_path = os.path.abspath(path)
    return absolute_path

# ---------- user-configurable defaults ----------
# Default keyword arguments for the processing stages. They are read when the
# method signatures that use them are defined, so changing this dict at
# runtime does not change those already-bound defaults.
DEFAULTS = {
    # multilook (range / azimuth) -- presumably for a later multilook stage
    "r_dec": 1,
    "az_dec": 5,
    "rwin": 3,
    "azwin": 15,
    # ORB_prop_SLC arguments; the interval is kept as a string
    "orb_interval": "0.5",
    "orb_extra": 150,
    "orb_mode": 2,
    # SLC_deskew arguments
    "interp": 0,
    "order": 5,
    "deramp": 1,
    "ph_corr": 1,
}

class SLC:
    """File layout and stage-1 GAMMA processing for one SWOT SLC NetCDF product.

    The constructor only derives path names (and creates the output
    directory); ``import_from_nc``, ``propagate_orbit`` and ``deskew`` run the
    corresponding GAMMA programs through ``py_gamma``.

    NOTE(review): the return values of the ``pg.*`` calls are not checked, so
    a failing step is not detected here -- confirm whether py_gamma raises on
    failure or whether the status should be inspected.
    """

    # Minimum number of "_"-separated fields in the file name needed to derive
    # the default image name (field index 7).
    _MIN_NAME_FIELDS = 8

    def __init__(self, filepath, savepath=None, imgname=None):
        """Derive every input/output path for the product at *filepath*.

        Parameters
        ----------
        filepath : str
            Path to the input NetCDF product.
        savepath : str, optional
            Output directory (created if missing). Defaults to the first seven
            "_"-separated fields of the file name + "_output_v0", relative to
            the current working directory.
        imgname : str, optional
            Root name for all products inside *savepath*. Defaults to the part
            before "T" of the eighth "_"-separated field of the file name.

        Raises
        ------
        ValueError
            If *imgname* must be derived but the file name has too few
            "_"-separated fields. Raised before any directory is created.
        """
        self.datapath = os.path.abspath(filepath)

        filename = os.path.basename(filepath)
        filename_parts = filename.split("_")

        # Resolve the image name first so a malformed file name fails before
        # ensure_dir() leaves an empty output directory behind.
        if imgname is None:
            if len(filename_parts) < self._MIN_NAME_FIELDS:
                raise ValueError(
                    f"cannot derive image name from {filename!r}: expected at "
                    f"least {self._MIN_NAME_FIELDS} '_'-separated fields; "
                    f"pass imgname explicitly"
                )
            imgname = filename_parts[7].split("T")[0]

        if savepath is None:
            savepath = "_".join(filename_parts[:7]) + "_output_v0"
        self.savepath = ensure_dir(savepath)

        # Common prefix of every product written for this image.
        self.img = os.path.join(self.savepath, imgname)

        # raw (skewed) SLCs produced by par_SWOT_SLC
        self.r_minus     = f"{self.img}_R_minus_y.slc"
        self.r_minus_par = f"{self.img}_R_minus_y.slc.par"
        self.r_plus      = f"{self.img}_R_plus_y.slc"
        self.r_plus_par  = f"{self.img}_R_plus_y.slc.par"

        # deskewed legs (RA from R_minus, RB from R_plus -- see deskew())
        self.ra_slc = f"{self.img}_RA.slc"
        self.ra_slc_par = f"{self.img}_RA.slc.par"
        self.rb_slc = f"{self.img}_RB.slc"
        self.rb_slc_par = f"{self.img}_RB.slc.par"

        # MLIs
        self.ra_mli = f"{self.img}_RA.mli"
        self.ra_mli_par = f"{self.img}_RA.mli.par"
        self.rb_mli = f"{self.img}_RB.mli"
        self.rb_mli_par = f"{self.img}_RB.mli.par"

        # offsets
        self.offset = f"{self.img}.RA_RB.off"

        # SCH DEM from par_SWOT_SLC
        self.sch_dem     = f"{self.img}.sch.dem"
        self.sch_dem_par = f"{self.img}.sch.dem_par"

    # ---- Stage 1: SLC initial processing ----
    def import_from_nc(self):
        """Run ``par_SWOT_SLC`` on the NetCDF product.

        Passes the product path, the image root name and the SCH DEM paths;
        the raw ``*_R_minus_y`` / ``*_R_plus_y`` SLCs used by the later steps
        are expected under the image root name.
        """
        pg.par_SWOT_SLC(self.datapath, self.img, self.sch_dem, self.sch_dem_par)
        print("[par_SWOT_SLC] wrote SLC legs + SCH DEM")

    def propagate_orbit(
            self,
            interval=DEFAULTS["orb_interval"],
            extra=DEFAULTS["orb_extra"],
            mode=DEFAULTS["orb_mode"]
            ):
        """Run ``ORB_prop_SLC`` on the parameter files of both raw legs.

        Only the ``*.slc.par`` paths are passed, so the updated state vectors
        presumably land in those files (confirm in the GAMMA docs).

        Parameters
        ----------
        interval : str or float
            State-vector interval forwarded to ORB_prop_SLC (logged in s).
        extra : int or float
            Extra time forwarded to ORB_prop_SLC (logged in s).
        mode : int
            Propagation mode forwarded to ORB_prop_SLC unchanged.
        """
        # "-" fills the positional slot before `interval` with the program
        # default (presumably the number of state vectors).
        for par_file in (self.r_minus_par, self.r_plus_par):
            pg.ORB_prop_SLC(par_file, "-", interval, extra, mode)
        print(f"[ORB_prop_SLC] interval={interval}s extra={extra}s mode={mode}")

    def deskew(
            self,
            interp=DEFAULTS["interp"],
            order=DEFAULTS["order"],
            deramp=DEFAULTS["deramp"],
            ph_corr=DEFAULTS["ph_corr"]
            ):
        """Run ``SLC_deskew`` on both raw legs.

        Writes ``ra_slc``/``ra_slc_par`` from the R_minus leg and
        ``rb_slc``/``rb_slc_par`` from the R_plus leg.

        Parameters
        ----------
        interp, order, deramp, ph_corr
            Forwarded to ``SLC_deskew`` unchanged, in this order.
        """
        # Each leg's own near_range_slc goes into the last argument --
        # presumably the output near range, so deskewing keeps the range
        # origin. NOTE(review): confirm against the SLC_deskew usage.
        nrA = pget(self.r_minus_par, "near_range_slc")
        nrB = pget(self.r_plus_par,  "near_range_slc")

        # The literal 0 fills the slot right after the output par file
        # (presumably the SLC_deskew mode -- confirm in the GAMMA docs).
        pg.SLC_deskew(self.r_minus, self.r_minus_par, self.ra_slc, self.ra_slc_par,
                      0, interp, order, deramp, ph_corr, nrA)
        pg.SLC_deskew(self.r_plus, self.r_plus_par, self.rb_slc, self.rb_slc_par,
                      0, interp, order, deramp, ph_corr, nrB)
        # Log the interp value actually used; a fixed "Lanczos" label did not
        # reflect the `interp` argument.
        print(f"[SLC_deskew] interp={interp} order={order}, deramp={deramp}, ph_corr={ph_corr}")

if __name__ == "__main__":
    # Example stage-1 run on a single SWOT L1B HR SLC product.
    filename = "data/SWOT_L1B_HR_SLC_036_001_237R_20250720T123100_20250720T123111_PID0_01.nc"

    # Not used by the stage-1 steps below; presumably kept for later
    # (DEM / geocoding) stages. NOTE(review): wire up or remove.
    dempath = "dem/dem_alps.tif"
    epsg = 4326
    post_deg = 0.0002777777  # ~1 arc-second (1/3600 deg)

    # Fail fast: SLC() creates the output directory immediately, and the
    # GAMMA steps would only fail later on a missing input.
    if not os.path.isfile(filename):
        sys.exit(f"input product not found: {filename}")

    slc = SLC(filename)
    slc.import_from_nc()
    slc.propagate_orbit()
    slc.deskew()

