From 759f768f86038c165bd8924a92e1ab98f861e5c7 Mon Sep 17 00:00:00 2001
From: Michael Chavinda
Date: Sat, 14 Mar 2026 12:06:10 -0700
Subject: [PATCH] feat: Add cast function

---
 src/DataFrame/Functions.hs            |  15 +++
 src/DataFrame/Internal/Expression.hs  |  12 ++
 src/DataFrame/Internal/Interpreter.hs | 155 ++++++++++++++++++++++++++
 3 files changed, 182 insertions(+)

diff --git a/src/DataFrame/Functions.hs b/src/DataFrame/Functions.hs
index b06ce7c..4458730 100644
--- a/src/DataFrame/Functions.hs
+++ b/src/DataFrame/Functions.hs
@@ -98,6 +98,21 @@ lift2Decorated f name rep comm prec =
             }
         )
 
+{- | Lenient numeric / text coercion. Builds a 'Cast' node for column @name@
+with target type @a@. When evaluated, @def@ stands in for @Nothing@ entries of
+optional columns and for text entries that fail to parse.
+-}
+cast :: forall a. (Columnable a) => a -> T.Text -> Expr a
+cast def name = Cast name def
+
+{- | Lenient coercion for assertedly non-nullable columns. Same as 'cast', but
+the default is an @error@ thunk, so forcing a result slot that came from a
+@Nothing@ crashes. Text entries that fail to parse receive the same thunk and
+therefore raise the same "unexpected Nothing" error when forced.
+-}
+unsafeCast :: forall a. (Columnable a) => T.Text -> Expr a
+unsafeCast name = Cast name (error "unsafeCast: unexpected Nothing in column")
+
 toDouble :: (Columnable a, Real a) => Expr a -> Expr Double
 toDouble =
     Unary
diff --git a/src/DataFrame/Internal/Expression.hs b/src/DataFrame/Internal/Expression.hs
index 65d65e3..ef53ea5 100644
--- a/src/DataFrame/Internal/Expression.hs
+++ b/src/DataFrame/Internal/Expression.hs
@@ -56,6 +56,7 @@ data AggStrategy a b where
 
 data Expr a where
     Col :: (Columnable a) => T.Text -> Expr a
+    Cast :: (Columnable a) => T.Text -> a -> Expr a
     Lit :: (Columnable a) => a -> Expr a
     Unary ::
         (Columnable a, Columnable b) => UnaryOp b a -> Expr b -> Expr a
@@ -228,6 +229,7 @@ instance (Floating a, Columnable a) => Floating (Expr a) where
 instance (Show a) => Show (Expr a) where
     show :: forall a. (Show a) => Expr a -> String
     show (Col name) = "(col @" ++ show (typeRep @a) ++ " " ++ show name ++ ")"
+    show (Cast name def) = "(cast @" ++ show (typeRep @a) ++ " " ++ show def ++ " " ++ show name ++ ")"
     show (Lit value) = "(lit (" ++ show value ++ "))"
     show (If cond l r) = "(ifThenElse " ++ show cond ++ " " ++ show l ++ " " ++ show r ++ ")"
     show (Unary op value) = "(" ++ T.unpack (unaryName op) ++ " " ++ show value ++ ")"
@@ -239,6 +241,7 @@ instance (Show a) => Show (Expr a) where
 normalize :: (Eq a, Ord a, Show a, Typeable a) => Expr a -> Expr a
 normalize expr = case expr of
     Col name -> Col name
+    Cast name def -> Cast name def
     Lit val -> Lit val
     If cond th el -> If (normalize cond) (normalize th) (normalize el)
     Unary op e -> Unary op (normalize e)
@@ -261,6 +264,7 @@ compareExpr e1 e2 = compare (exprKey e1) (exprKey e2)
   where
     exprKey :: Expr a -> String
     exprKey (Col name) = "0:" ++ T.unpack name
+    exprKey (Cast name _) = "0C:" ++ T.unpack name
     exprKey (Lit val) = "1:" ++ show val
     exprKey (If c t e) = "2:" ++ exprKey c ++ exprKey t ++ exprKey e
     exprKey (Unary op e) = "3:" ++ T.unpack (unaryName op) ++ exprKey e
@@ -278,6 +282,7 @@ instance (Eq a, Columnable a) => Eq (Expr a) where
             Nothing -> False
         eqNormalized :: Expr a -> Expr a -> Bool
         eqNormalized (Col n1) (Col n2) = n1 == n2
+        eqNormalized (Cast n1 d1) (Cast n2 d2) = n1 == n2 && d1 == d2
         eqNormalized (Lit v1) (Lit v2) = v1 == v2
         eqNormalized (If c1 t1 e1) (If c2 t2 e2) =
             c1 == c2 && t1 `exprEq` t2 && e1 `exprEq` e2
@@ -295,6 +300,7 @@ instance (Ord a, Columnable a) => Ord (Expr a) where
     compare :: Expr a -> Expr a -> Ordering
     compare e1 e2 = case (e1, e2) of
         (Col n1, Col n2) -> compare n1 n2
+        (Cast n1 d1, Cast n2 d2) -> compare n1 n2 <> compare d1 d2
         (Lit v1, Lit v2) -> compare v1 v2
         (If c1 t1 e1', If c2 t2 e2') ->
             compare c1 c2 <> exprComp t1 t2 <> exprComp e1' e2'
@@ -307,6 +313,8 @@ instance (Ord a, Columnable a) => Ord (Expr a) where
         -- Different constructors - compare by priority
         (Col _, _) -> LT
         (_, Col _) -> GT
+        (Cast _ _, _) -> LT
+        (_, Cast _ _) -> GT
         (Lit _, _) -> LT
         (_, Lit _) -> GT
         (Unary{}, _) -> LT
@@ -334,6 +342,7 @@ replaceExpr new old expr = case testEquality (typeRep @b) (typeRep @c) of
   where
     replace' = case expr of
         (Col _) -> expr
+        (Cast _ _) -> expr
         (Lit _) -> expr
         (If cond l r) ->
             If (replaceExpr new old cond) (replaceExpr new old l) (replaceExpr new old r)
@@ -343,6 +352,7 @@ replaceExpr new old expr = case testEquality (typeRep @b) (typeRep @c) of
 
 eSize :: Expr a -> Int
 eSize (Col _) = 1
+eSize (Cast _ _) = 1
 eSize (Lit _) = 1
 eSize (If c l r) = 1 + eSize c + eSize l + eSize r
 eSize (Unary _ e) = 1 + eSize e
@@ -351,6 +361,7 @@ eSize (Agg strategy expr) = eSize expr + 1
 
 getColumns :: Expr a -> [T.Text]
 getColumns (Col cName) = [cName]
+getColumns (Cast name _) = [name]
 getColumns expr@(Lit _) = []
 getColumns (If cond l r) = getColumns cond <> getColumns l <> getColumns r
 getColumns (Unary op value) = getColumns value
@@ -366,6 +377,7 @@ prettyPrint = go 0 0
     go :: Int -> Int -> Expr a -> String
     go depth prec expr = case expr of
         Col name -> T.unpack name
+        Cast name _ -> T.unpack name
         Lit value -> show value
         If cond t e ->
             let inner =
diff --git a/src/DataFrame/Internal/Interpreter.hs b/src/DataFrame/Internal/Interpreter.hs
index ad8de13..8876382 100644
--- a/src/DataFrame/Internal/Interpreter.hs
+++ b/src/DataFrame/Internal/Interpreter.hs
@@ -230,6 +230,141 @@ sliceGroups col os indices = case col of
 numGroups :: GroupedDataFrame -> Int
 numGroups gdf = VU.length (offsets gdf) - 1
 
+-------------------------------------------------------------------------------
+-- promoteColumn: numeric / text coercion for Cast
+-------------------------------------------------------------------------------
+
+{- | Coerce a column to element type @a@. A column for which @hasElemType \@a@
+holds is returned unchanged. Double, Float and Int targets dispatch on the
+'sFloating' / 'sIntegral' singletons of the source type (floating to Int uses
+'round'); String / T.Text sources are parsed via 'reads'. @def@ replaces
+@Nothing@ and failed parses; any other mismatch is a 'TypeMismatchException'.
+-}
+promoteColumn ::
+    forall a. (Columnable a) => a -> Column -> Either DataFrameException Column
+promoteColumn def col
+    | hasElemType @a col = Right col
+    | otherwise =
+        case testEquality (typeRep @a) (typeRep @Double) of
+            Just Refl -> promoteToDouble def col
+            Nothing ->
+                case testEquality (typeRep @a) (typeRep @Float) of
+                    Just Refl -> promoteToFloat def col
+                    Nothing ->
+                        case testEquality (typeRep @a) (typeRep @Int) of
+                            Just Refl -> promoteToInt def col
+                            Nothing -> tryParse @a def col
+
+promoteToDouble :: Double -> Column -> Either DataFrameException Column
+promoteToDouble def col = case col of
+    UnboxedColumn (v :: VU.Vector b) ->
+        case sFloating @b of
+            STrue -> Right $ fromUnboxedVector @Double (VU.map (realToFrac :: b -> Double) v)
+            SFalse -> case sIntegral @b of
+                STrue -> Right $ fromUnboxedVector @Double (VU.map (fromIntegral :: b -> Double) v)
+                SFalse -> castMismatch @b @Double
+    OptionalColumn (v :: V.Vector (Maybe b)) ->
+        case sFloating @b of
+            STrue ->
+                Right $
+                    fromUnboxedVector @Double
+                        (VG.convert $ V.map (maybe def (realToFrac :: b -> Double)) v)
+            SFalse -> case sIntegral @b of
+                STrue ->
+                    Right $
+                        fromUnboxedVector @Double
+                            (VG.convert $ V.map (maybe def (fromIntegral :: b -> Double)) v)
+                SFalse -> tryParse @Double def col
+    BoxedColumn _ -> tryParse @Double def col
+
+promoteToFloat :: Float -> Column -> Either DataFrameException Column
+promoteToFloat def col = case col of
+    UnboxedColumn (v :: VU.Vector b) ->
+        case sFloating @b of
+            STrue -> Right $ fromUnboxedVector @Float (VU.map (realToFrac :: b -> Float) v)
+            SFalse -> case sIntegral @b of
+                STrue -> Right $ fromUnboxedVector @Float (VU.map (fromIntegral :: b -> Float) v)
+                SFalse -> castMismatch @b @Float
+    OptionalColumn (v :: V.Vector (Maybe b)) ->
+        case sFloating @b of
+            STrue ->
+                Right $
+                    fromUnboxedVector @Float
+                        (VG.convert $ V.map (maybe def (realToFrac :: b -> Float)) v)
+            SFalse -> case sIntegral @b of
+                STrue ->
+                    Right $
+                        fromUnboxedVector @Float
+                            (VG.convert $ V.map (maybe def (fromIntegral :: b -> Float)) v)
+                SFalse -> tryParse @Float def col
+    BoxedColumn _ -> tryParse @Float def col
+
+promoteToInt :: Int -> Column -> Either DataFrameException Column
+promoteToInt def col = case col of
+    UnboxedColumn (v :: VU.Vector b) ->
+        case sFloating @b of
+            STrue ->
+                Right $
+                    fromUnboxedVector @Int
+                        (VU.map (round . (realToFrac :: b -> Double)) v)
+            SFalse -> case sIntegral @b of
+                STrue -> Right $ fromUnboxedVector @Int (VU.map (fromIntegral :: b -> Int) v)
+                SFalse -> castMismatch @b @Int
+    OptionalColumn (v :: V.Vector (Maybe b)) ->
+        case sFloating @b of
+            STrue ->
+                Right $
+                    fromUnboxedVector @Int
+                        (VG.convert $ V.map (maybe def (round . (realToFrac :: b -> Double))) v)
+            SFalse -> case sIntegral @b of
+                STrue ->
+                    Right $
+                        fromUnboxedVector @Int
+                            (VG.convert $ V.map (maybe def (fromIntegral :: b -> Int)) v)
+                SFalse -> tryParse @Int def col
+    BoxedColumn _ -> tryParse @Int def col
+
+tryParse ::
+    forall a. (Columnable a) => a -> Column -> Either DataFrameException Column
+tryParse def col = case col of
+    BoxedColumn (v :: V.Vector b) ->
+        case testEquality (typeRep @b) (typeRep @String) of
+            Just Refl -> Right $ fromVector @a $ V.map (parseValue def) v
+            Nothing ->
+                case testEquality (typeRep @b) (typeRep @T.Text) of
+                    Just Refl -> Right $ fromVector @a $ V.map (parseValue def . T.unpack) v
+                    Nothing -> castMismatch @b @a
+    OptionalColumn (v :: V.Vector (Maybe b)) ->
+        case testEquality (typeRep @b) (typeRep @String) of
+            Just Refl -> Right $ fromVector @a $ V.map (maybe def (parseValue def)) v
+            Nothing ->
+                case testEquality (typeRep @b) (typeRep @T.Text) of
+                    Just Refl ->
+                        Right $
+                            fromVector @a $
+                                V.map (maybe def (parseValue def . T.unpack)) v
+                    Nothing -> castMismatch @b @a
+    UnboxedColumn (_ :: VU.Vector b) -> castMismatch @b @a
+
+parseValue :: (Read a) => a -> String -> a
+parseValue def s = case reads s of
+    [(x, "")] -> x -- exactly one parse, nothing left over
+    _ -> def -- no parse, ambiguous parse, or trailing input
+
+castMismatch ::
+    forall src tgt.
+    (Typeable src, Typeable tgt) =>
+    Either DataFrameException Column
+castMismatch =
+    Left $
+        TypeMismatchException
+            MkTypeErrorContext
+                { userType = Right (typeRep @tgt) -- the type the cast asked for
+                , expectedType = Right (typeRep @src) -- the column's element type
+                , callingFunctionName = Just "cast"
+                , errorColumnName = Nothing
+                }
+
 -------------------------------------------------------------------------------
 -- eval: the unified interpreter
 -------------------------------------------------------------------------------
@@ -264,6 +399,26 @@ eval (GroupCtx gdf) (Col name) =
                 ( Group (sliceGroups c (offsets gdf) (valueIndices gdf))
                 )
 
+-- Cast -------------------------------------------------------------------
+
+eval (FlatCtx df) (Cast name def) =
+    case getColumn name df of
+        Nothing ->
+            Left $
+                ColumnNotFoundException name "" (M.keys $ columnIndices df)
+        Just c -> Flat <$> promoteColumn @a def c
+eval (GroupCtx gdf) (Cast name def) =
+    case getColumn name (fullDataframe gdf) of
+        Nothing ->
+            Left $
+                ColumnNotFoundException
+                    name
+                    ""
+                    (M.keys $ columnIndices $ fullDataframe gdf)
+        Just c -> do
+            promoted <- promoteColumn @a def c -- coerce the full column, then slice
+            Right $ Group (sliceGroups promoted (offsets gdf) (valueIndices gdf))
+
 -- Unary ------------------------------------------------------------------
 
 eval ctx expr@(Unary (op :: UnaryOp b a) inner) = addContext expr $ do