Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept of "side channel" for diagnostics #639

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ packages: semantic
semantic-tags
semantic-tsx
semantic-typescript
-- ../../tree-sitter/haskell-tree-sitter/tree-sitter

source-repository-package
type: git
location: https://github.com/antitypical/fused-syntax.git
tag: 4a383d57c8fd7592f54a33f43eb9666314a6e80e

source-repository-package
type: git
location: https://github.com/alanz/haskell-tree-sitter.git
tag: ec8a82021017f644c77156c33a249073d013f73b
subdir: tree-sitter
6 changes: 6 additions & 0 deletions cabal.project.ci
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ source-repository-package
location: https://github.com/antitypical/fused-syntax.git
tag: 4a383d57c8fd7592f54a33f43eb9666314a6e80e

source-repository-package
type: git
location: https://github.com/alanz/haskell-tree-sitter.git
tag: 8f6d94e6bbb7035851492160fa6e642c8612d2b
subdir: tree-sitter

-- Treat warnings as errors for CI builds
package semantic
ghc-options: -Werror
Expand Down
164 changes: 153 additions & 11 deletions semantic-ast/src/AST/Unmarshal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedStrings #-}

module AST.Unmarshal
( parseByteString
, UnmarshalState(..)
, UnmarshalDiagnostics(..)
, TSDiagnostic(..)
, parseDiagnostics
, UnmarshalError(..)
, FieldName(..)
, Unmarshal(..)
Expand All @@ -29,9 +33,11 @@ module AST.Unmarshal

import AST.Token as TS
import AST.Parse
import Control.Carrier.Reader
import Control.Carrier.State.Strict
import Control.Exception
import Control.Monad (void)
import Control.Monad.IO.Class
import Data.Attoparsec.Text as Attoparsec (Parser, char, takeWhile1, string, many', endOfInput, decimal, choice, parseOnly)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Coerce
Expand All @@ -44,6 +50,7 @@ import qualified Data.Text as Text
import Data.Text.Encoding
import Data.Text.Encoding.Error (lenientDecode)
import Foreign.C.String
import Foreign.Marshal.Alloc (alloca, free)
import Foreign.Marshal.Array
import Foreign.Marshal.Utils
import Foreign.Ptr
Expand All @@ -58,29 +65,54 @@ import TreeSitter.Language as TS
import TreeSitter.Node as TS
import TreeSitter.Parser as TS
import TreeSitter.Tree as TS
import Control.Applicative ((<|>))

-- Parse source code and produce AST
parseByteString :: (Unmarshal t, UnmarshalAnn a) => Ptr TS.Language -> ByteString -> IO (Either String (t a))
parseByteString :: (Unmarshal t, UnmarshalAnn a) => Ptr TS.Language -> ByteString -> IO (Either String (UnmarshalDiagnostics, (t a)))
parseByteString language bytestring = withParser language $ \ parser -> withParseTree parser bytestring $ \ treePtr ->
if treePtr == nullPtr then
pure (Left "error: didn't get a root node")
else
withRootNode treePtr $ \ rootPtr ->
withCursor (castPtr rootPtr) $ \ cursor ->
(Right <$> runReader (UnmarshalState bytestring cursor) (liftIO (peek rootPtr) >>= unmarshalNode))
`catch` (pure . Left . getUnmarshalError)
else do
r <-
withRootNode treePtr $ \ rootPtr ->
withCursor (castPtr rootPtr) $ \ cursor ->
-- (Right <$> runReader (UnmarshalState bytestring cursor) (liftIO (peek rootPtr) >>= unmarshalNode))
(Right <$> runState (UnmarshalState bytestring cursor mempty) (liftIO (peek rootPtr) >>= unmarshalNode))
`catch` (pure . Left . getUnmarshalError)
case r of
Left e -> pure $ Left e
Right (s, res) -> pure $ Right (diagnostics s, res)

newtype UnmarshalError = UnmarshalError { getUnmarshalError :: String }
deriving (Show)

instance Exception UnmarshalError

-- newtype UnmarshalDiagnostics = UnmarshalDiagnostics [(Loc, String)]
newtype UnmarshalDiagnostics = UnmarshalDiagnostics [(Range, [((Int,Int), TSDiagnostic)])]
deriving (Show, Eq)

instance Semigroup UnmarshalDiagnostics where
UnmarshalDiagnostics a <> UnmarshalDiagnostics b = UnmarshalDiagnostics (a <> b)

instance Monoid UnmarshalDiagnostics where
mempty = UnmarshalDiagnostics []

data UnmarshalState = UnmarshalState
{ source :: {-# UNPACK #-} !ByteString
, cursor :: {-# UNPACK #-} !(Ptr Cursor)
, diagnostics :: !UnmarshalDiagnostics
}

type MatchM = ReaderC UnmarshalState IO
-- type MatchM = ReaderC UnmarshalState IO
type MatchM = StateC UnmarshalState IO

-- runReader :: r -> ReaderC r m a -> m a
-- newtype ReaderC r m a

-- runState :: s -> StateC s m a -> m (s, a)
-- evalState :: forall s m a. Functor m => s -> StateC s m a -> m a
-- newtype StateC s m a

newtype Match t = Match
{ runMatch :: forall a . UnmarshalAnn a => Node -> MatchM (t a)
Expand Down Expand Up @@ -142,10 +174,26 @@ class SymbolMatching t => Unmarshal t where
default matchers :: (Generic1 t, GUnmarshal (Rep1 t)) => B (Int, Match t)
matchers = foldMap (singleton . (, match)) (matchedSymbols (Proxy @t))
where match = Match $ \ node -> do
cursor <- asks cursor
cursor <- gets cursor
goto cursor (nodeTSNode node)
-- Note: checkDiagnostics could be made optional, for
-- batch usage
checkDiagnostics node
fmap to1 (gunmarshalNode node)

-- | Check if the Node has any tree-sitter problems, such as being an
-- ERROR or MISSING or UNEXPECTED node
checkDiagnostics :: Node -> MatchM ()
checkDiagnostics node = do
(Loc loc _) <- unmarshalAnn @Loc node
dds <- liftIO . alloca $ (\p -> poke p (nodeTSNode node) >> ts_node_string_diagnostics_p p)
str <- liftIO $ peekCString dds
liftIO $ free dds
let dd = [(loc, parseDiagnostics (Text.pack str)) | not (null str)]

modify (\s -> s { diagnostics = diagnostics s <> UnmarshalDiagnostics dd })


instance (Unmarshal f, Unmarshal g) => Unmarshal (f :+: g) where
matchers = fmap (fmap (hoist L1)) matchers <> fmap (fmap (hoist R1)) matchers

Expand Down Expand Up @@ -175,7 +223,7 @@ instance UnmarshalAnn () where
instance UnmarshalAnn Text.Text where
unmarshalAnn node = do
range <- unmarshalAnn node
asks (decodeUtf8With lenientDecode . slice range . source)
gets (decodeUtf8With lenientDecode . slice range . source)

-- | Instance for pairs of annotations
instance (UnmarshalAnn a, UnmarshalAnn b) => UnmarshalAnn (a,b) where
Expand Down Expand Up @@ -362,7 +410,7 @@ instance Unmarshal t => GUnmarshalData (Rec1 t) where

-- For product datatypes:
instance (GUnmarshalProduct f, GUnmarshalProduct g) => GUnmarshalData (f :*: g) where
gunmarshalNode' datatypeName node = asks cursor >>= flip getFields node >>= gunmarshalProductNode @(f :*: g) datatypeName node
gunmarshalNode' datatypeName node = gets cursor >>= flip getFields node >>= gunmarshalProductNode @(f :*: g) datatypeName node


-- | Generically unmarshal products
Expand Down Expand Up @@ -421,3 +469,97 @@ instance (GHasAnn ann l, GHasAnn ann r) => GHasAnn ann (l :+: r) where

instance {-# OVERLAPPABLE #-} HasField "ann" (t ann) ann => GHasAnn ann t where
gann = getField @"ann"

-- ---------------------------------------------------------------------
-- Perhaps this belongs in its own module


-- [(Range {start = 11, end = 23},"((16,19) (ERROR)),((24,24) (MISSING \".\")),")]

data TSDiagnostic = TSDError
| TSDMissing Text.Text
| TSDUnexpected Text.Text
deriving (Eq,Show)

parseDiagnostics :: Text.Text -> [((Int,Int), TSDiagnostic)]
parseDiagnostics str = r
where
r = case parseOnly posdiagnostics str of
-- TODO: proper handling
Left err -> [((1,0), TSDUnexpected ("parseDiagnostics:" <> Text.pack err <> "parsing [" <> str <> "]"))]
Right ds -> ds

posdiagnostics :: Attoparsec.Parser [((Int,Int), TSDiagnostic)]
posdiagnostics = do
xs <- many' posdiagnostic
void endOfInput
pure xs

-- "((16,19) (ERROR)),((24,24) (MISSING \".\")),")]
-- ((15,15) (MISSING {_raw_atom}))
posdiagnostic :: Attoparsec.Parser ((Int,Int), TSDiagnostic)
posdiagnostic = do
void $ char '('
pos <- pos
void $ char ' '
d <- diag
void $ char ')'
void $ char ','
pure (pos,d)

pos :: Attoparsec.Parser (Int,Int)
pos = do
void $ char '('
l <- decimal
void $ char ','
c <- decimal
void $ char ')'
pure (l,c)

diag :: Attoparsec.Parser TSDiagnostic
diag = do
void $ char '('
d <- choice [perror,punexpected,pmissing]
void $ char ')'
pure d

perror :: Attoparsec.Parser TSDiagnostic
perror = do
void $ string "ERROR"
pure TSDError

pmissing :: Attoparsec.Parser TSDiagnostic
pmissing = do
void $ string "MISSING"
void $ char ' '
s <- pquoted
pure (TSDMissing s)

punexpected :: Attoparsec.Parser TSDiagnostic
punexpected = do
void $ string "UNEXPECTED"
void $ char ' '
s <- pquoted
pure (TSDUnexpected s)

pquoted :: Attoparsec.Parser Text.Text
pquoted =
pquotedString
<|> pbracedString

pquotedString :: Attoparsec.Parser Text.Text
pquotedString = do
void $ char '"'
s <- takeWhile1 (/= '"')
void $ char '"'
pure s

-- TODO: this should probably be an SEXP parser. But we wrap the
-- missng part in braces in the tree-sitter C part, so should work
-- unless there is a '}' in the quoted fragment.
pbracedString :: Attoparsec.Parser Text.Text
pbracedString = do
void $ char '{'
s <- takeWhile1 (/= '}')
void $ char '}'
pure s
5 changes: 3 additions & 2 deletions semantic/src/Parsing/TreeSitter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ module Parsing.TreeSitter
, parseToPreciseAST
) where

import Control.Carrier.Reader
import Control.Carrier.State.Strict
-- import Control.Carrier.Reader
import Control.Exception as Exc
import Control.Monad.IO.Class
import Foreign
Expand Down Expand Up @@ -51,7 +52,7 @@ parseToPreciseAST
parseToPreciseAST parseTimeout unmarshalTimeout language blob = runParse parseTimeout language blob $ \ rootPtr ->
withTimeout $
TS.withCursor (castPtr rootPtr) $ \ cursor ->
runReader (TS.UnmarshalState (Source.bytes (blobSource blob)) cursor) (liftIO (peek rootPtr) >>= TS.unmarshalNode)
runState (TS.UnmarshalState (Source.bytes (blobSource blob)) cursor mempty) (liftIO (peek rootPtr) >>= TS.unmarshalNode)
`Exc.catch` (Exc.throw . UnmarshalFailure . TS.getUnmarshalError)
where
withTimeout :: IO a -> IO a
Expand Down