cassandra-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jmcken...@apache.org
Subject [18/32] cassandra git commit: 2.2 commit for CASSANDRA-9160
Date Wed, 24 Jun 2015 16:15:11 GMT
http://git-wip-us.apache.org/repos/asf/cassandra/blob/01115f72/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java
new file mode 100644
index 0000000..498f0dd
--- /dev/null
+++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java
@@ -0,0 +1,728 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.cql3.validation.entities;
+
+import java.lang.reflect.Field;
+import java.util.*;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableSet;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.auth.*;
+import org.apache.cassandra.config.DatabaseDescriptor;
+import org.apache.cassandra.cql3.Attributes;
+import org.apache.cassandra.cql3.CQLStatement;
+import org.apache.cassandra.cql3.QueryProcessor;
+import org.apache.cassandra.cql3.functions.Function;
+import org.apache.cassandra.cql3.functions.FunctionName;
+import org.apache.cassandra.cql3.functions.Functions;
+import org.apache.cassandra.cql3.statements.BatchStatement;
+import org.apache.cassandra.cql3.statements.ModificationStatement;
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.exceptions.*;
+import org.apache.cassandra.service.ClientState;
+import org.apache.cassandra.utils.Pair;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Checks that CQL statements which reference user-defined functions or aggregates require
+ * EXECUTE permission on each such function. The cases covered include functions nested in
+ * other functions, the state and final functions of aggregates, and the functions used by
+ * the statements of a batch.
+ *
+ * A {@link StubAuthorizer} is installed into {@link DatabaseDescriptor} by reflection, and
+ * the test user is injected directly into a {@link ClientState}. Access checks therefore
+ * run against in-memory grants made by the tests themselves.
+ */
+public class UFAuthTest extends CQLTester
+{
+    private static final Logger logger = LoggerFactory.getLogger(UFAuthTest.class);
+
+    // Role that every statement in these tests is access-checked as.
+    String roleName = "test_role";
+    AuthenticatedUser user;
+    RoleResource role;
+    // Client state whose logged-in user is forced to `user` in setupClientState().
+    ClientState clientState;
+
+    /**
+     * Replaces the configured authorizer with a {@link StubAuthorizer} and sets the
+     * permissions validity to 0. Presumably the validity setting keeps cached permissions
+     * from masking grants and revokes made during a test.
+     */
+    @BeforeClass
+    public static void setupAuthorizer()
+    {
+        try
+        {
+            IAuthorizer authorizer = new StubAuthorizer();
+            Field authorizerField = DatabaseDescriptor.class.getDeclaredField("authorizer");
+            authorizerField.setAccessible(true);
+            authorizerField.set(null, authorizer);
+            DatabaseDescriptor.setPermissionsValidity(0);
+        }
+        catch (IllegalAccessException | NoSuchFieldException e)
+        {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Drops all stub grants, rebuilds the client state, and creates a fresh table that the
+     * test role may SELECT from and MODIFY.
+     */
+    @Before
+    public void setup() throws Throwable
+    {
+        ((StubAuthorizer) DatabaseDescriptor.getAuthorizer()).clear();
+        setupClientState();
+        setupTable("CREATE TABLE %s (k int, v1 int, v2 int, PRIMARY KEY (k, v1))");
+    }
+
+    @Test
+    public void functionInSelection() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT k, %s FROM %s WHERE k = 1;",
+                                   functionCall(functionName),
+                                   KEYSPACE + "." + currentTable());
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectPKRestriction() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectClusteringRestriction() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k = 0 AND v1 = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectInRestriction() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k IN (%s, %s)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectMultiColumnInRestriction() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) IN ((%s, %s))",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectMultiColumnEQRestriction() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) = (%s, %s)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectMultiColumnSliceRestriction() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) < (%s, %s)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectTokenEQRestriction() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE token(k) = token(%s)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInSelectTokenSliceRestriction() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT * FROM %s WHERE token(k) < token(%s)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInPKForInsert() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("INSERT INTO %s (k, v1, v2) VALUES (%s, 0, 0)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInClusteringValuesForInsert() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("INSERT INTO %s (k, v1, v2) VALUES (0, %s, 0)",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInPKForDelete() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("DELETE FROM %s WHERE k = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInClusteringValuesForDelete() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("DELETE FROM %s WHERE k = 0 AND v1 = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void testBatchStatement() throws Throwable
+    {
+        // Each batched INSERT uses its own function. The batch stays unauthorized until
+        // EXECUTE has been granted on all of them.
+        List<ModificationStatement> statements = new ArrayList<>();
+        List<String> functions = new ArrayList<>();
+        for (int i = 0; i < 3; i++)
+        {
+            String functionName = createSimpleFunction();
+            ModificationStatement stmt =
+            (ModificationStatement) getStatement(String.format("INSERT INTO %s (k, v1, v2) " +
+                                                               "VALUES (%s, %s, %s)",
+                                                               KEYSPACE + "." + currentTable(),
+                                                               i, i, functionCall(functionName)));
+            functions.add(functionName);
+            statements.add(stmt);
+        }
+        BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none());
+        assertUnauthorized(batch, functions);
+
+        grantExecuteOnFunction(functions.get(0));
+        assertUnauthorized(batch, functions.subList(1, functions.size()));
+
+        grantExecuteOnFunction(functions.get(1));
+        assertUnauthorized(batch, functions.subList(2, functions.size()));
+
+        grantExecuteOnFunction(functions.get(2));
+        batch.checkAccess(clientState);
+    }
+
+    @Test
+    public void testNestedFunctions() throws Throwable
+    {
+        String innerFunctionName = createSimpleFunction();
+        String outerFunctionName = createFunction("int",
+                                                  "CREATE FUNCTION %s(input int) " +
+                                                  " CALLED ON NULL INPUT" +
+                                                  " RETURNS int" +
+                                                  " LANGUAGE java" +
+                                                  " AS 'return Integer.valueOf(0);'");
+        assertPermissionsOnNestedFunctions(innerFunctionName, outerFunctionName);
+    }
+
+    @Test
+    public void functionInStaticColumnRestrictionInSelect() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, s int STATIC, v1 int, v2 int, PRIMARY KEY(k, v1))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("SELECT k FROM %s WHERE k = 0 AND s = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInRegularCondition() throws Throwable
+    {
+        String functionName = createSimpleFunction();
+        String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = 0 AND v1 = 0 IF v2 = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInStaticColumnCondition() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, s int STATIC, v1 int, v2 int, PRIMARY KEY(k, v1))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = 0 AND v1 = 0 IF s = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInCollectionLiteralCondition() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, v1 int, m_val map<int, int>, PRIMARY KEY(k))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("UPDATE %s SET v1 = 0 WHERE k = 0 IF m_val = {%s : %s}",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void functionInCollectionElementCondition() throws Throwable
+    {
+        setupTable("CREATE TABLE %s (k int, v1 int, m_val map<int, int>, PRIMARY KEY(k))");
+        String functionName = createSimpleFunction();
+        String cql = String.format("UPDATE %s SET v1 = 0 WHERE k = 0 IF m_val[%s] = %s",
+                                   KEYSPACE + "." + currentTable(),
+                                   functionCall(functionName),
+                                   functionCall(functionName));
+        assertPermissionsOnFunction(cql, functionName);
+    }
+
+    @Test
+    public void systemFunctionsRequireNoExplicitPrivileges() throws Throwable
+    {
+        // with terminal arguments, so evaluated at prepare time
+        String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = blobasint(intasblob(0))",
+                                   KEYSPACE + "." + currentTable());
+        getStatement(cql).checkAccess(clientState);
+
+        // with non-terminal arguments, so evaluated at execution
+        String functionName = createSimpleFunction();
+        grantExecuteOnFunction(functionName);
+        cql = String.format("UPDATE %s SET v2 = 0 WHERE k = blobasint(intasblob(%s))",
+                            KEYSPACE + "." + currentTable(),
+                            functionCall(functionName));
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    @Test
+    public void requireExecutePermissionOnComponentFunctionsWhenDefiningAggregate() throws Throwable
+    {
+        String sFunc = createSimpleStateFunction();
+        String fFunc = createSimpleFinalFunction();
+        // aside from the component functions, we need CREATE on the keyspace's functions
+        DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER,
+                                                 ImmutableSet.of(Permission.CREATE),
+                                                 FunctionResource.keyspace(KEYSPACE),
+                                                 role);
+        String aggDef = String.format(aggregateCql(sFunc, fFunc),
+                                      KEYSPACE + ".aggregate_for_permissions_test");
+
+        assertUnauthorized(aggDef, sFunc, "int, int");
+        grantExecuteOnFunction(sFunc);
+
+        assertUnauthorized(aggDef, fFunc, "int");
+        grantExecuteOnFunction(fFunc);
+
+        getStatement(aggDef).checkAccess(clientState);
+    }
+
+    @Test
+    public void revokeExecutePermissionsOnAggregateComponents() throws Throwable
+    {
+        String sFunc = createSimpleStateFunction();
+        String fFunc = createSimpleFinalFunction();
+        String aggDef = aggregateCql(sFunc, fFunc);
+        grantExecuteOnFunction(sFunc);
+        grantExecuteOnFunction(fFunc);
+
+        String aggregate = createAggregate(KEYSPACE, "int", aggDef);
+        grantExecuteOnFunction(aggregate);
+
+        String cql = String.format("SELECT %s(v1) FROM %s",
+                                   aggregate,
+                                   KEYSPACE + "." + currentTable());
+        getStatement(cql).checkAccess(clientState);
+
+        // check that revoking EXECUTE permission on any one of the
+        // component functions means we lose the ability to execute it
+        revokeExecuteOnFunction(aggregate);
+        assertUnauthorized(cql, aggregate, "int");
+        grantExecuteOnFunction(aggregate);
+        getStatement(cql).checkAccess(clientState);
+
+        revokeExecuteOnFunction(sFunc);
+        assertUnauthorized(cql, sFunc, "int, int");
+        grantExecuteOnFunction(sFunc);
+        getStatement(cql).checkAccess(clientState);
+
+        revokeExecuteOnFunction(fFunc);
+        assertUnauthorized(cql, fFunc, "int");
+        grantExecuteOnFunction(fFunc);
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    @Test
+    public void functionWrappingAggregate() throws Throwable
+    {
+        String outerFunc = createFunction("int",
+                                          "CREATE FUNCTION %s(input int) " +
+                                          "CALLED ON NULL INPUT " +
+                                          "RETURNS int " +
+                                          "LANGUAGE java " +
+                                          "AS 'return input;'");
+
+        String sFunc = createSimpleStateFunction();
+        String fFunc = createSimpleFinalFunction();
+        String aggDef = aggregateCql(sFunc, fFunc);
+        grantExecuteOnFunction(sFunc);
+        grantExecuteOnFunction(fFunc);
+
+        String aggregate = createAggregate(KEYSPACE, "int", aggDef);
+
+        String cql = String.format("SELECT %s(%s(v1)) FROM %s",
+                                   outerFunc,
+                                   aggregate,
+                                   KEYSPACE + "." + currentTable());
+
+        assertUnauthorized(cql, outerFunc, "int");
+        grantExecuteOnFunction(outerFunc);
+
+        assertUnauthorized(cql, aggregate, "int");
+        grantExecuteOnFunction(aggregate);
+
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    @Test
+    public void aggregateWrappingFunction() throws Throwable
+    {
+        String innerFunc = createFunction("int",
+                                          "CREATE FUNCTION %s(input int) " +
+                                          "CALLED ON NULL INPUT " +
+                                          "RETURNS int " +
+                                          "LANGUAGE java " +
+                                          "AS 'return input;'");
+
+        String sFunc = createSimpleStateFunction();
+        String fFunc = createSimpleFinalFunction();
+        String aggDef = aggregateCql(sFunc, fFunc);
+        grantExecuteOnFunction(sFunc);
+        grantExecuteOnFunction(fFunc);
+
+        String aggregate = createAggregate(KEYSPACE, "int", aggDef);
+
+        String cql = String.format("SELECT %s(%s(v1)) FROM %s",
+                                   aggregate,
+                                   innerFunc,
+                                   KEYSPACE + "." + currentTable());
+
+        assertUnauthorized(cql, aggregate, "int");
+        grantExecuteOnFunction(aggregate);
+
+        assertUnauthorized(cql, innerFunc, "int");
+        grantExecuteOnFunction(innerFunc);
+
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    /**
+     * Selects {@code outer(inner())} and checks that EXECUTE is required on both functions.
+     * The outer function is reported first; once it is granted, the inner one is reported.
+     */
+    private void assertPermissionsOnNestedFunctions(String innerFunction, String outerFunction) throws Throwable
+    {
+        String cql = String.format("SELECT k, %s FROM %s WHERE k=0",
+                                   functionCall(outerFunction, functionCall(innerFunction)),
+                                   KEYSPACE + "." + currentTable());
+        // fail fast with an UnauthorizedException on the outer function
+        assertUnauthorized(cql, outerFunction, "int");
+        grantExecuteOnFunction(outerFunction);
+
+        // after granting execute on the outer function, still fail due to the inner function
+        assertUnauthorized(cql, innerFunction, "");
+        grantExecuteOnFunction(innerFunction);
+
+        // now execution of both is permitted
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    /** Same as {@link #assertPermissionsOnFunction(String, String, String)} for a no-argument function. */
+    private void assertPermissionsOnFunction(String cql, String functionName) throws Throwable
+    {
+        assertPermissionsOnFunction(cql, functionName, "");
+    }
+
+    /**
+     * Asserts that {@code cql} is rejected for lack of EXECUTE on {@code functionName}. It then
+     * grants EXECUTE on that function and asserts that the access check passes.
+     */
+    private void assertPermissionsOnFunction(String cql, String functionName, String argTypes) throws Throwable
+    {
+        assertUnauthorized(cql, functionName, argTypes);
+        grantExecuteOnFunction(functionName);
+        getStatement(cql).checkAccess(clientState);
+    }
+
+    /**
+     * Asserts that checking access on {@code batch} fails with an UnauthorizedException for a
+     * no-argument function. The function named in the error may be any one of
+     * {@code functionNames}.
+     */
+    private void assertUnauthorized(BatchStatement batch, Iterable<String> functionNames) throws Throwable
+    {
+        try
+        {
+            batch.checkAccess(clientState);
+            fail("Expected an UnauthorizedException, but none was thrown");
+        }
+        catch (UnauthorizedException e)
+        {
+            // Build an alternation of the candidate names, regex-quoted because the names are
+            // keyspace-qualified and '.' would otherwise match any character.
+            List<String> quotedNames = new ArrayList<>();
+            for (String functionName : functionNames)
+                quotedNames.add(Pattern.quote(functionName));
+            String expected = String.format("User %s has no EXECUTE permission on <function (%s)\\(\\)> or any of its parents",
+                                            Pattern.quote(roleName),
+                                            Joiner.on("|").join(quotedNames));
+            String message = e.getLocalizedMessage();
+            assertTrue(String.format("Unexpected error message: %s", message),
+                       message != null && message.matches(expected));
+        }
+    }
+
+    /**
+     * Asserts that checking access on {@code cql} fails with an UnauthorizedException whose
+     * message names exactly {@code functionName(argTypes)}.
+     */
+    private void assertUnauthorized(String cql, String functionName, String argTypes) throws Throwable
+    {
+        try
+        {
+            getStatement(cql).checkAccess(clientState);
+            fail("Expected an UnauthorizedException, but none was thrown");
+        }
+        catch (UnauthorizedException e)
+        {
+            assertEquals(String.format("User %s has no EXECUTE permission on <function %s(%s)> or any of its parents",
+                                       roleName,
+                                       functionName,
+                                       argTypes),
+                         e.getLocalizedMessage());
+        }
+    }
+
+    /** Grants EXECUTE on the (uniquely named) function to the test role. */
+    private void grantExecuteOnFunction(String functionName)
+    {
+        DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER,
+                                                 ImmutableSet.of(Permission.EXECUTE),
+                                                 functionResource(functionName),
+                                                 role);
+    }
+
+    /** Revokes EXECUTE on the (uniquely named) function from the test role. */
+    private void revokeExecuteOnFunction(String functionName)
+    {
+        DatabaseDescriptor.getAuthorizer().revoke(AuthenticatedUser.SYSTEM_USER,
+                                                  ImmutableSet.of(Permission.EXECUTE),
+                                                  functionResource(functionName),
+                                                  role);
+    }
+
+    /**
+     * Creates the test role resource and user, then injects the user into a fresh internal
+     * ClientState by reflection.
+     */
+    void setupClientState()
+    {
+        try
+        {
+            role = RoleResource.role(roleName);
+            // use reflection to set the logged in user so that we don't need to
+            // bother setting up an IRoleManager
+            user = new AuthenticatedUser(roleName);
+            clientState = ClientState.forInternalCalls();
+            Field userField = ClientState.class.getDeclaredField("user");
+            userField.setAccessible(true);
+            userField.set(clientState, user);
+        }
+        catch (IllegalAccessException | NoSuchFieldException e)
+        {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Creates a table from {@code tableDef} and grants the test role SELECT and MODIFY on it.
+     * Requires setupClientState() to have run so that {@code role} is set.
+     */
+    private void setupTable(String tableDef) throws Throwable
+    {
+        createTable(tableDef);
+        // test user needs SELECT & MODIFY on the table regardless of permissions on any function
+        DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER,
+                                                 ImmutableSet.of(Permission.SELECT, Permission.MODIFY),
+                                                 DataResource.table(KEYSPACE, currentTable()),
+                                                 role);
+    }
+
+    /**
+     * Returns a CREATE AGGREGATE template over {@code (int)} that uses the given state and
+     * final functions. The aggregate's name is left as a {@code %s} placeholder.
+     */
+    private String aggregateCql(String sFunc, String fFunc)
+    {
+        return "CREATE AGGREGATE %s(int) " +
+               "SFUNC " + shortFunctionName(sFunc) + " " +
+               "STYPE int " +
+               "FINALFUNC " + shortFunctionName(fFunc) + " " +
+               "INITCOND 0";
+    }
+
+    /** Creates a state function {@code (int, int) -> int} that sums its arguments, treating a null state as 0. */
+    private String createSimpleStateFunction() throws Throwable
+    {
+        return createFunction("int, int",
+                              "CREATE FUNCTION %s(a int, b int) " +
+                              "CALLED ON NULL INPUT " +
+                              "RETURNS int " +
+                              "LANGUAGE java " +
+                              "AS 'return Integer.valueOf( (a != null ? a.intValue() : 0 ) + b.intValue());'");
+    }
+
+    /** Creates an identity final function {@code (int) -> int}. */
+    private String createSimpleFinalFunction() throws Throwable
+    {
+        return createFunction("int",
+                              "CREATE FUNCTION %s(a int) " +
+                              "CALLED ON NULL INPUT " +
+                              "RETURNS int " +
+                              "LANGUAGE java " +
+                              "AS 'return a;'");
+    }
+
+    /** Creates a no-argument function that always returns 0. */
+    private String createSimpleFunction() throws Throwable
+    {
+        return createFunction("",
+                              "CREATE FUNCTION %s() " +
+                              "  CALLED ON NULL INPUT " +
+                              "  RETURNS int " +
+                              "  LANGUAGE java " +
+                              "  AS 'return Integer.valueOf(0);'");
+    }
+
+    private String createFunction(String argTypes, String functionDef) throws Throwable
+    {
+        return createFunction(KEYSPACE, argTypes, functionDef);
+    }
+
+    /** Parses {@code cql} against the test client state and returns the prepared statement. */
+    private CQLStatement getStatement(String cql)
+    {
+        return QueryProcessor.getStatement(cql, clientState).statement;
+    }
+
+    private FunctionResource functionResource(String functionName)
+    {
+        // Note that this is somewhat brittle as it assumes that function names are
+        // truly unique. As such, it will break in the face of overloading.
+        // It is here to avoid having to duplicate the functionality of CqlParser
+        // for transforming cql types into AbstractTypes
+        FunctionName fn = parseFunctionName(functionName);
+        List<Function> functions = Functions.find(fn);
+        assertEquals(String.format("Expected a single function definition for %s, but found %s",
+                                   functionName,
+                                   functions.size()),
+                     1, functions.size());
+        return FunctionResource.function(fn.keyspace, fn.name, functions.get(0).argTypes());
+    }
+
+    /** Renders a CQL call expression {@code name(arg1,arg2,...)}. */
+    private String functionCall(String functionName, String...args)
+    {
+        return String.format("%s(%s)", functionName, Joiner.on(",").join(args));
+    }
+
+    /**
+     * Minimal in-memory {@link IAuthorizer}. Permissions are stored per (role name, resource)
+     * pair. {@link #authorize} returns only what was granted on exactly that resource; the
+     * stub itself never consults parent resources. It is backed by a plain HashMap, so it is
+     * not meant for concurrent use.
+     */
+    static class StubAuthorizer implements IAuthorizer
+    {
+        private final Map<Pair<String, IResource>, Set<Permission>> userPermissions = new HashMap<>();
+
+        private void clear()
+        {
+            userPermissions.clear();
+        }
+
+        public Set<Permission> authorize(AuthenticatedUser user, IResource resource)
+        {
+            Pair<String, IResource> key = Pair.create(user.getName(), resource);
+            Set<Permission> perms = userPermissions.get(key);
+            return perms != null ? perms : Collections.<Permission>emptySet();
+        }
+
+        public void grant(AuthenticatedUser performer,
+                          Set<Permission> permissions,
+                          IResource resource,
+                          RoleResource grantee) throws RequestValidationException, RequestExecutionException
+        {
+            Pair<String, IResource> key = Pair.create(grantee.getRoleName(), resource);
+            Set<Permission> perms = userPermissions.get(key);
+            if (null == perms)
+            {
+                perms = new HashSet<>();
+                userPermissions.put(key, perms);
+            }
+            perms.addAll(permissions);
+        }
+
+        public void revoke(AuthenticatedUser performer,
+                           Set<Permission> permissions,
+                           IResource resource,
+                           RoleResource revokee) throws RequestValidationException, RequestExecutionException
+        {
+            Pair<String, IResource> key = Pair.create(revokee.getRoleName(), resource);
+            Set<Permission> perms = userPermissions.get(key);
+            // Revoking from a (role, resource) pair that holds no grants is a no-op.
+            if (perms == null)
+                return;
+            perms.removeAll(permissions);
+            // Drop the entry once nothing is left so the map holds no empty sets.
+            if (perms.isEmpty())
+                userPermissions.remove(key);
+        }
+
+        public Set<PermissionDetails> list(AuthenticatedUser performer,
+                                           Set<Permission> permissions,
+                                           IResource resource,
+                                           RoleResource grantee) throws RequestValidationException, RequestExecutionException
+        {
+            Pair<String, IResource> key = Pair.create(grantee.getRoleName(), resource);
+            Set<Permission> perms = userPermissions.get(key);
+            if (perms == null)
+                return Collections.emptySet();
+
+            Set<PermissionDetails> details = new HashSet<>();
+            for (Permission permission : perms)
+            {
+                if (permissions.contains(permission))
+                    details.add(new PermissionDetails(grantee.getRoleName(), resource, permission));
+            }
+            return details;
+        }
+
+        public void revokeAllFrom(RoleResource revokee)
+        {
+            // Remove through the iterator: calling userPermissions.remove() while iterating
+            // the key set would throw ConcurrentModificationException.
+            Iterator<Pair<String, IResource>> keys = userPermissions.keySet().iterator();
+            while (keys.hasNext())
+            {
+                if (keys.next().left.equals(revokee.getRoleName()))
+                    keys.remove();
+            }
+        }
+
+        public void revokeAllOn(IResource droppedResource)
+        {
+            // Remove through the iterator: calling userPermissions.remove() while iterating
+            // the key set would throw ConcurrentModificationException.
+            Iterator<Pair<String, IResource>> keys = userPermissions.keySet().iterator();
+            while (keys.hasNext())
+            {
+                if (keys.next().right.equals(droppedResource))
+                    keys.remove();
+            }
+        }
+
+        public Set<? extends IResource> protectedResources()
+        {
+            return Collections.emptySet();
+        }
+
+        public void validateConfiguration() throws ConfigurationException
+        {
+            // nothing to validate for the stub
+        }
+
+        public void setup()
+        {
+            // nothing to set up for the stub
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/cassandra/blob/01115f72/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java
new file mode 100644
index 0000000..28b8afc
--- /dev/null
+++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.cql3.validation.entities;
+
+import java.util.*;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.Iterables;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import org.apache.cassandra.cql3.Attributes;
+import org.apache.cassandra.cql3.CQLStatement;
+import org.apache.cassandra.cql3.QueryProcessor;
+import org.apache.cassandra.cql3.functions.Function;
+import org.apache.cassandra.cql3.statements.BatchStatement;
+import org.apache.cassandra.cql3.statements.ModificationStatement;
+import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.service.ClientState;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Checks the collection of Function objects returned by CQLStatement.getFunction
+ * matches expectations. This is intended to verify the various subcomponents of
+ * the statement (Operations, Terms, Restrictions, RestrictionSet, Selection,
+ * Selector, SelectorFactories etc) properly report any constituent functions.
+ * Some purely terminal functions are resolved at preparation, so those are not
+ * included in the reported list. They still need to be surveyed, to verify the
+ * calling client has the necessary permissions. UFAuthTest includes tests which
+ * verify this more thoroughly than we can here.
+ */
+public class UFIdentificationTest extends CQLTester
+{
+    private com.google.common.base.Function<Function, String> toFunctionNames = new com.google.common.base.Function<Function, String>()
+    {
+        public String apply(Function f)
+        {
+            return f.name().keyspace + "." + f.name().name;
+        }
+    };
+
+    String tFunc;
+    String iFunc;
+    String lFunc;
+    String sFunc;
+    String mFunc;
+    String uFunc;
+    String udtFunc;
+
+    String userType;
+
+    @Before
+    public void setup() throws Throwable
+    {
+        userType = KEYSPACE + "." + createType("CREATE TYPE %s (t text, i int)");
+
+        createTable("CREATE TABLE %s (" +
+                    "   key int, " +
+                    "   t_sc text STATIC," +
+                    "   i_cc int, " +
+                    "   t_cc text, " +
+                    "   i_val int," +
+                    "   l_val list<int>," +
+                    "   s_val set<int>," +
+                    "   m_val map<int, int>," +
+                    "   u_val timeuuid," +
+                    "   udt_val frozen<" + userType + ">," +
+                    "   PRIMARY KEY (key, i_cc, t_cc)" +
+                    ")");
+
+        tFunc = createEchoFunction("text");
+        iFunc = createEchoFunction("int");
+        lFunc = createEchoFunction("list<int>");
+        sFunc = createEchoFunction("set<int>");
+        mFunc = createEchoFunction("map<int, int>");
+        uFunc = createEchoFunction("timeuuid");
+        udtFunc = createEchoFunction(userType);
+    }
+
+    @Test
+    public void testSimpleModificationStatement() throws Throwable
+    {
+        assertFunctions(cql("INSERT INTO %s (key, t_sc) VALUES (0, %s)", functionCall(tFunc, "'foo'")), tFunc);
+        assertFunctions(cql("INSERT INTO %s (key, i_cc) VALUES (0, %s)", functionCall(iFunc, "1")), iFunc);
+        assertFunctions(cql("INSERT INTO %s (key, t_cc) VALUES (0, %s)", functionCall(tFunc, "'foo'")), tFunc);
+        assertFunctions(cql("INSERT INTO %s (key, i_val) VALUES (0, %s)", functionCall(iFunc, "1")), iFunc);
+        assertFunctions(cql("INSERT INTO %s (key, l_val) VALUES (0, %s)", functionCall(lFunc, "[1]")), lFunc);
+        assertFunctions(cql("INSERT INTO %s (key, s_val) VALUES (0, %s)", functionCall(sFunc, "{1}")), sFunc);
+        assertFunctions(cql("INSERT INTO %s (key, m_val) VALUES (0, %s)", functionCall(mFunc, "{1:1}")), mFunc);
+        assertFunctions(cql("INSERT INTO %s (key, udt_val) VALUES (0,%s)", functionCall(udtFunc, "{i : 1, t : 'foo'}")), udtFunc);
+        assertFunctions(cql("INSERT INTO %s (key, u_val) VALUES (0, %s)", functionCall(uFunc, "now()")), uFunc, "system.now");
+    }
+
+    @Test
+    public void testNonTerminalCollectionLiterals() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        String mapValue = String.format("{%s:%s}", functionCall(iFunc, "1"), functionCall(iFunc2, "1"));
+        assertFunctions(cql("INSERT INTO %s (key, m_val) VALUES (0, %s)", mapValue), iFunc, iFunc2);
+
+        String listValue = String.format("[%s]", functionCall(iFunc, "1"));
+        assertFunctions(cql("INSERT INTO %s (key, l_val) VALUES (0, %s)", listValue), iFunc);
+
+        String setValue = String.format("{%s}", functionCall(iFunc, "1"));
+        assertFunctions(cql("INSERT INTO %s (key, s_val) VALUES (0, %s)", setValue), iFunc);
+    }
+
+    @Test
+    public void testNonTerminalUDTLiterals() throws Throwable
+    {
+        String udtValue = String.format("{ i: %s, t : %s } ", functionCall(iFunc, "1"), functionCall(tFunc, "'foo'"));
+        assertFunctions(cql("INSERT INTO %s (key, udt_val) VALUES (0, %s)", udtValue), iFunc, tFunc);
+    }
+
+    @Test
+    public void testModificationStatementWithConditions() throws Throwable
+    {
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF t_sc=%s", functionCall(tFunc, "'foo'")), tFunc);
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF i_val=%s", functionCall(iFunc, "1")), iFunc);
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF l_val=%s", functionCall(lFunc, "[1]")), lFunc);
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF s_val=%s", functionCall(sFunc, "{1}")), sFunc);
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF m_val=%s", functionCall(mFunc, "{1:1}")), mFunc);
+
+
+        String iFunc2 = createEchoFunction("int");
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF i_val IN (%s, %S)",
+                            functionCall(iFunc, "1"),
+                            functionCall(iFunc2, "2")),
+                        iFunc, iFunc2);
+
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF u_val=%s",
+                            functionCall(uFunc, "now()")),
+                        uFunc, "system.now");
+
+        // conditions on collection elements
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF l_val[%s] = %s",
+                            functionCall(iFunc, "1"),
+                            functionCall(iFunc2, "1")),
+                        iFunc, iFunc2);
+        assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF m_val[%s] = %s",
+                            functionCall(iFunc, "1"),
+                            functionCall(iFunc2, "1")),
+                        iFunc, iFunc2);
+    }
+
+    @Test @Ignore
+    // Technically, attributes like timestamp and ttl are Terms so could potentially
+    // resolve to function calls (& so you can call getFunctions on them)
+    // However, this is currently disallowed by CQL syntax
+    public void testModificationStatementWithAttributesFromFunction() throws Throwable
+    {
+        String longFunc = createEchoFunction("bigint");
+        assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TIMESTAMP %s",
+                            functionCall(longFunc, "9999")),
+                        longFunc);
+
+        assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TTL %s",
+                            functionCall(iFunc, "8888")),
+                        iFunc);
+
+        assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TIMESTAMP %s AND TTL %s",
+                            functionCall(longFunc, "9999"), functionCall(iFunc, "8888")),
+                        longFunc, iFunc);
+    }
+
+    @Test
+    public void testModificationStatementWithNestedFunctions() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        String iFunc3 = createEchoFunction("int");
+        String iFunc4 = createEchoFunction("int");
+        String iFunc5 = createEchoFunction("int");
+        String iFunc6 = createEchoFunction("int");
+        String nestedFunctionCall = nestedFunctionCall(iFunc6, iFunc5,
+                                                       nestedFunctionCall(iFunc4, iFunc3,
+                                                                          nestedFunctionCall(iFunc2, iFunc, "1")));
+
+        assertFunctions(cql("DELETE FROM %s WHERE key=%s", nestedFunctionCall),
+                        iFunc, iFunc2, iFunc3, iFunc4, iFunc5, iFunc6);
+    }
+
+    @Test
+    public void testSelectStatementSimpleRestrictions() throws Throwable
+    {
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s", functionCall(iFunc, "1")), iFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND t_sc=%s", functionCall(tFunc, "'foo'")), tFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND i_cc=%s AND t_cc='foo'", functionCall(iFunc, "1")), iFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND i_cc=0 AND t_cc=%s", functionCall(tFunc, "'foo'")), tFunc);
+
+        String iFunc2 = createEchoFunction("int");
+        String tFunc2 = createEchoFunction("text");
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s AND t_sc=%s AND i_cc=%s AND t_cc=%s",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'"),
+                            functionCall(iFunc2, "1"),
+                            functionCall(tFunc2, "'foo'")),
+                        iFunc, tFunc, iFunc2, tFunc2);
+    }
+
+    @Test
+    public void testSelectStatementRestrictionsWithNestedFunctions() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        String iFunc3 = createEchoFunction("int");
+        String iFunc4 = createEchoFunction("int");
+        String iFunc5 = createEchoFunction("int");
+        String iFunc6 = createEchoFunction("int");
+        String nestedFunctionCall = nestedFunctionCall(iFunc6, iFunc5,
+                                                       nestedFunctionCall(iFunc3, iFunc4,
+                                                                          nestedFunctionCall(iFunc, iFunc2, "1")));
+
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s", nestedFunctionCall),
+                        iFunc, iFunc2, iFunc3, iFunc4, iFunc5, iFunc6);
+    }
+
+    @Test
+    public void testNonTerminalTupleInSelectRestrictions() throws Throwable
+    {
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) IN ((%s, %s))",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'")),
+                        iFunc, tFunc);
+
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) = (%s, %s)",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'")),
+                        iFunc, tFunc);
+
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) > (%s, %s)",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'")),
+                        iFunc, tFunc);
+
+        assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) < (%s, %s)",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'")),
+                        iFunc, tFunc);
+
+         assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) > (%s, %s) AND (i_cc, t_cc) < (%s, %s)",
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'"),
+                            functionCall(iFunc, "1"),
+                            functionCall(tFunc, "'foo'")),
+                         iFunc, tFunc);
+    }
+
+    @Test
+    public void testNestedFunctionInTokenRestriction() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) = token(%s)", functionCall(iFunc, "1")),
+                        "system.token", iFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) > token(%s)", functionCall(iFunc, "1")),
+                        "system.token", iFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) < token(%s)", functionCall(iFunc, "1")),
+                        "system.token", iFunc);
+        assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) > token(%s) AND token(key) < token(%s)",
+                            functionCall(iFunc, "1"),
+                            functionCall(iFunc2, "1")),
+                        "system.token", iFunc, iFunc2);
+    }
+
+    @Test
+    public void testSelectStatementSimpleSelections() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        execute("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0)");
+        assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", functionCall(iFunc, "i_val")), iFunc);
+        assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", nestedFunctionCall(iFunc, iFunc2, "i_val")), iFunc, iFunc2);
+    }
+
+    @Test
+    public void testSelectStatementNestedSelections() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        execute("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0)");
+        assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", functionCall(iFunc, "i_val")), iFunc);
+        assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", nestedFunctionCall(iFunc, iFunc2, "i_val")), iFunc, iFunc2);
+    }
+
+    @Test
+    public void testBatchStatement() throws Throwable
+    {
+        String iFunc2 = createEchoFunction("int");
+        List<ModificationStatement> statements = new ArrayList<>();
+        statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (%s, 0, 'foo')",
+                                                 functionCall(iFunc, "0"))));
+        statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (1, %s, 'foo')",
+                                                 functionCall(iFunc2, "1"))));
+        statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (2, 2, %s)",
+                                                 functionCall(tFunc, "'foo'"))));
+
+        BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none());
+        assertFunctions(batch, iFunc, iFunc2, tFunc);
+    }
+
+    @Test
+    public void testBatchStatementWithConditions() throws Throwable
+    {
+        List<ModificationStatement> statements = new ArrayList<>();
+        statements.add(modificationStatement(cql("UPDATE %s SET i_val = %s WHERE key=0 AND i_cc=0 and t_cc='foo' IF l_val = %s",
+                                                 functionCall(iFunc, "0"), functionCall(lFunc, "[1]"))));
+        statements.add(modificationStatement(cql("UPDATE %s SET i_val = %s WHERE key=0 AND i_cc=1 and t_cc='foo' IF s_val = %s",
+                                                 functionCall(iFunc, "0"), functionCall(sFunc, "{1}"))));
+
+        BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none());
+        assertFunctions(batch, iFunc, lFunc, sFunc);
+    }
+
+    private ModificationStatement modificationStatement(String cql)
+    {
+        return (ModificationStatement) QueryProcessor.getStatement(cql, ClientState.forInternalCalls()).statement;
+    }
+
+    private void assertFunctions(String cql, String... function)
+    {
+        CQLStatement stmt = QueryProcessor.getStatement(cql, ClientState.forInternalCalls()).statement;
+        assertFunctions(stmt, function);
+    }
+
+    private void assertFunctions(CQLStatement stmt, String... function)
+    {
+        Set<String> expected = com.google.common.collect.Sets.newHashSet(function);
+        Set<String> actual = com.google.common.collect.Sets.newHashSet(Iterables.transform(stmt.getFunctions(),
+                                                                                           toFunctionNames));
+        assertTrue(com.google.common.collect.Sets.symmetricDifference(expected, actual).isEmpty());
+    }
+
+    private String cql(String template, String... params)
+    {
+        String tableName = KEYSPACE + "." + currentTable();
+        return String.format(template, com.google.common.collect.Lists.asList(tableName, params).toArray());
+    }
+
+    // Alternative query builder - appends the table name to the supplied params,
+    // for stmts of the form "SELECT x, %s FROM %s WHERE y=0"
+    private String cql2(String template, String... params)
+    {
+        Object[] args = Arrays.copyOf(params, params.length + 1);
+        args[params.length] = KEYSPACE + "." + currentTable();
+        return String.format(template, args);
+    }
+
+    private String functionCall(String fName, String... args)
+    {
+        return String.format("%s(%s)", fName, Joiner.on(",").join(args));
+    }
+
+    private String nestedFunctionCall(String outer, String inner, String innerArgs)
+    {
+        return functionCall(outer, functionCall(inner, innerArgs));
+    }
+
+    private String createEchoFunction(String type) throws Throwable
+    {
+        return createFunction(KEYSPACE, type,
+           "CREATE FUNCTION %s(input " + type + ")" +
+           " CALLED ON NULL INPUT" +
+           " RETURNS " + type +
+           " LANGUAGE java" +
+           " AS ' return input;'");
+    }
+}


Mime
View raw message